Mirror of https://github.com/hwchase17/langchain.git (synced 2026-02-03 15:55:44 +00:00)
Compare commits: 67 commits (eugene/pin...eugene/aut)
| Author | SHA1 | Date |
|---|---|---|
|  | f0ea4b8b25 |  |
|  | d45d2294fc |  |
|  | 9933c6c44c |  |
|  | f869131743 |  |
|  | db116dd4cd |  |
|  | 47cdc8873c |  |
|  | 5a8c0ab5c5 |  |
|  | ff4d1d150a |  |
|  | cd3a57602d |  |
|  | 5601d60df0 |  |
|  | 1cd9fb6444 |  |
|  | 6eb5a1c5e6 |  |
|  | 71dee4c885 |  |
|  | b3cbff741e |  |
|  | 00c3a2d3d2 |  |
|  | ba61d64c04 |  |
|  | 0396570c82 |  |
|  | cf6bca3da7 |  |
|  | 95d7e87a68 |  |
|  | 62b936dc7c |  |
|  | 23e1ef8f4a |  |
|  | 4108862154 |  |
|  | c9cbfc9d51 |  |
|  | 96f3fc2d58 |  |
|  | 486d772184 |  |
|  | 29ddf63abb |  |
|  | 35318007d8 |  |
|  | 900cf6ead3 |  |
|  | beec2d99b9 |  |
|  | a695ba5ba4 |  |
|  | 71c277c17c |  |
|  | e67f4b8d3b |  |
|  | 7181211d78 |  |
|  | 2c798fcc77 |  |
|  | bf12905e7f |  |
|  | e4cf87ab34 |  |
|  | 9bbc5af2a8 |  |
|  | 06e34e54f5 |  |
|  | 5a68ba4174 |  |
|  | 8e8c368019 |  |
|  | acbd19fac5 |  |
|  | aa56956d16 |  |
|  | 0d2cf70d1e |  |
|  | 188ac56c97 |  |
|  | 383f09b271 |  |
|  | 0af306fc54 |  |
|  | 15d5c49076 |  |
|  | 65660535bc |  |
|  | 6c41dd82f0 |  |
|  | 78d788c28c |  |
|  | e9deeab37f |  |
|  | 4d595eec5b |  |
|  | 047b001336 |  |
|  | 840e936c7c |  |
|  | 6cc6b490be |  |
|  | f45d1ed4f5 |  |
|  | 4ffc417858 |  |
|  | 9e74a70859 |  |
|  | 0997f2c0f1 |  |
|  | 437b545426 |  |
|  | 9b9d07572b |  |
|  | 6a90c6c2c8 |  |
|  | 8371187689 |  |
|  | 4309c17ffa |  |
|  | 183a9d4e66 |  |
|  | c1b444e1e7 |  |
|  | 5f117384c0 |  |
libs/langchain/langchain/automaton/__init__.py (new file, 0 lines)
libs/langchain/langchain/automaton/agent.py (new file, 75 lines)
@@ -0,0 +1,75 @@
"""Generalized chat agent, works with any chat model."""
from __future__ import annotations

from typing import Callable, Optional, Sequence, TypeVar, List

from langchain.automaton.typedefs import MessageLike, AgentFinish, Agent
from langchain.schema.runnable import (
    Runnable,
    RunnableConfig,
)

T = TypeVar("T")


WorkingMemoryProcessor = Runnable[Sequence[MessageLike], List[MessageLike]]
LLMProgram = Runnable[Sequence[MessageLike], List[MessageLike]]


class SequentialAgent(Agent):
    def __init__(
        self,
        llm_program: LLMProgram,
        memory_processor: Optional[WorkingMemoryProcessor] = None,
    ) -> None:
        """Initialize the agent."""
        self.llm_program = llm_program
        self.memory_processor = memory_processor

    def run(
        self,
        messages: Sequence[MessageLike],
        *,
        config: Optional[RunnableConfig] = None,
        max_iterations: int = 10,
    ) -> List[MessageLike]:
        """Run the agent."""
        messages = list(messages)
        for iteration_num in range(max_iterations):
            # Working memory / working state updates can take the form
            # of appends or replacements
            if self.memory_processor:
                # This is a replacement
                messages = self.memory_processor.invoke(messages)
            if messages and isinstance(messages[-1], AgentFinish):
                break

            # This is an append to working memory
            messages.extend(self.llm_program.invoke(messages, config=config))

        return messages


class MessageAutomaton:  # Just a sketch
    def __init__(
        self,
        router: Callable[[Sequence[MessageLike]], Optional[WorkingMemoryProcessor]],
    ) -> None:
        """Initialize the automaton."""
        self.router = router

    def run(
        self,
        messages: Sequence[MessageLike],
        *,
        config: Optional[RunnableConfig] = None,
        max_iterations: int = 10,
    ) -> List[MessageLike]:
        """Run the automaton."""
        new_messages = list(messages)
        for _ in range(max_iterations):
            runnable = self.router(new_messages)
            if not runnable:
                break
            new_messages.extend(runnable.invoke(new_messages, config=config))
        return new_messages
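
A minimal usage sketch (not part of the diff): the memory_processor slot accepts any Runnable over message lists, so a replacement-style trimmer can be a plain RunnableLambda. Here `llm_program` is assumed to come from create_llm_program (see runnables.py later in this diff) and `_keep_recent` is a hypothetical helper.

# Sketch only; names marked above are assumptions, not diff contents.
from langchain.schema.messages import SystemMessage
from langchain.schema.runnable import RunnableLambda

def _keep_recent(messages):
    # Replacement-style update: keep system messages plus the last 5 entries.
    system = [m for m in messages if isinstance(m, SystemMessage)]
    rest = [m for m in messages if not isinstance(m, SystemMessage)]
    return system + rest[-5:]

# agent = SequentialAgent(llm_program, memory_processor=RunnableLambda(_keep_recent))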
libs/langchain/langchain/automaton/agent_implementations/openai_agent.py (new file, 67 lines)

@@ -0,0 +1,67 @@
from __future__ import annotations

import json
from typing import List, Sequence

from langchain.automaton.agent import SequentialAgent
from langchain.automaton.prompt_generator import AdapterBasedGenerator
from langchain.automaton.runnables import create_llm_program
from langchain.automaton.typedefs import (
    AgentFinish, FunctionCallRequest, FunctionCallResponse,
)
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
from langchain.schema import AIMessage, FunctionMessage, Generation
from langchain.schema.language_model import (
    BaseLanguageModel, LanguageModelInput, LanguageModelOutput,
)
from langchain.schema.output_parser import BaseGenerationOutputParser
from langchain.schema.runnable import Runnable
from langchain.tools import BaseTool, format_tool_to_openai_function


class OpenAIFunctionsParser(BaseGenerationOutputParser):
    def parse_result(self, result: List[Generation]):
        if len(result) != 1:
            raise AssertionError("Expected exactly one result")
        first_result = result[0]

        message = first_result.message

        if not isinstance(message, AIMessage) or not message.additional_kwargs:
            return AgentFinish(result=message)

        parser = JsonOutputFunctionsParser(strict=False, args_only=False)
        try:
            function_request = parser.parse_result(result)
        except Exception as e:
            raise RuntimeError(f"Error parsing result: {result} {repr(e)}") from e

        return FunctionCallRequest(
            name=function_request["name"],
            named_arguments=function_request["arguments"],
        )


def create_openai_agent(
    llm: BaseLanguageModel[LanguageModelInput, LanguageModelOutput]
    | Runnable[LanguageModelInput, LanguageModelOutput],
    tools: Sequence[BaseTool],
) -> SequentialAgent:
    """Create an agent that uses OpenAI's API."""
    openai_funcs = [format_tool_to_openai_function(tool_) for tool_ in tools]
    prompt_generator = AdapterBasedGenerator(
        msg_adapters={
            FunctionCallResponse: lambda message: FunctionMessage(
                name=message.name, content=json.dumps(message.result)
            ),
            # No need to translate function call requests
        },
    )
    llm_program = create_llm_program(
        llm.bind(functions=openai_funcs),
        prompt_generator=prompt_generator,
        tools=tools,
        parser=OpenAIFunctionsParser(),
    )
    return SequentialAgent(llm_program, memory_processor=None)
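
A hedged illustration (not part of the diff) of what OpenAIFunctionsParser produces: a plain assistant message becomes an AgentFinish, while a message carrying a function_call in additional_kwargs becomes a FunctionCallRequest.

# Hypothetical check; ChatGeneration wraps the message as parse_result expects.
from langchain.schema import AIMessage, ChatGeneration

plain = ChatGeneration(message=AIMessage(content="hello"))
# OpenAIFunctionsParser().parse_result([plain])
# -> AgentFinish(result=AIMessage(content="hello"))

call = ChatGeneration(
    message=AIMessage(
        content="",
        additional_kwargs={
            "function_call": {"name": "add", "arguments": '{"x": 1, "y": 2}'}
        },
    )
)
# OpenAIFunctionsParser().parse_result([call])
# -> FunctionCallRequest(name="add", named_arguments={"x": 1, "y": 2})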
libs/langchain/langchain/automaton/chat_agent.py (new file, 113 lines)

@@ -0,0 +1,113 @@
"""Generalized chat agent, works with any chat model."""
from __future__ import annotations

from typing import (
    Iterator,
    Optional,
    Sequence,
    TypeVar,
    List,
)

from langchain.automaton.processors import WorkingMemoryManager
from langchain.automaton.runnables import create_retriever
from langchain.automaton.typedefs import (
    Agent,
    AgentFinish,
    MessageLike,
    RetrievalRequest,
    RetrievalResponse,
)
from langchain.schema import BaseRetriever
from langchain.schema.messages import BaseMessage, HumanMessage, AIMessage
from langchain.schema.runnable import (
    Runnable,
    RunnableConfig,
)

T = TypeVar("T")


def prompt_generator(input_messages: Sequence[MessageLike]) -> List[BaseMessage]:
    """Generate a prompt from a log of message-like objects."""
    messages = []
    for message in input_messages:
        if isinstance(message, BaseMessage):
            messages.append(message)
        elif isinstance(message, RetrievalResponse):
            prompt = ""

            if message.results:
                for idx, doc in enumerate(message.results):
                    prompt += f"--- Result {idx} ---\n"
                    prompt += "Text:\n"
                    prompt += f"{doc.page_content}"
                    for field in ["title", "description", "source"]:
                        if field in doc.metadata:
                            prompt += f"{doc.metadata[field]}\n"
                    prompt += f"--- End Result {idx} ---\n"
            else:
                prompt = "Found no results for the query."

            messages.append(
                HumanMessage(
                    content="Context: <result>\n" + prompt + "\n</result>",
                )
            )
        else:
            # Internal message types are skipped.
            pass
    return messages


class RagAgent(Agent):
    """A generalized chat agent."""

    def __init__(
        self,
        llm_program: Runnable[Sequence[MessageLike], List[MessageLike]],
        retriever: Optional[BaseRetriever] = None,
        memory_manager: Optional[WorkingMemoryManager] = None,
    ) -> None:
        """Initialize the chat agent."""
        self.llm_program = llm_program
        self.retriever = create_retriever(retriever)
        self.memory_manager = memory_manager

    def step(
        self,
        messages: Sequence[MessageLike],
        *,
        config: Optional[RunnableConfig] = None,
    ) -> List[MessageLike]:
        """Take a single step with the agent."""
        last_message = messages[-1] if messages else None
        if not last_message:
            return []
        match last_message:
            case AIMessage():
                return []
            case AgentFinish():
                return []
            case HumanMessage():
                return [RetrievalRequest(query=last_message.content)]
            case RetrievalRequest():
                return [self.retriever.invoke(last_message, config=config)]
            case _:
                return self.llm_program.invoke(messages, config=config)

    def run(
        self,
        messages: Sequence[MessageLike],
        *,
        config: Optional[dict] = None,
        max_iterations: int = 100,
    ) -> Iterator[MessageLike]:
        """Run the agent."""
        all_messages = list(messages)
        for _ in range(max_iterations):
            if self.memory_manager:  # Guard: memory_manager defaults to None
                all_messages = self.memory_manager.process(all_messages)
            new_messages = self.step(all_messages, config=config)
            if not new_messages:
                break
            yield from new_messages
            all_messages.extend(new_messages)
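
A sketch (not from the diff) of wiring RagAgent to a retriever; FakeRetriever and llm_program are hypothetical stand-ins, with llm_program assumed to come from create_llm_program in runnables.py below.

from typing import List

from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.schema import BaseRetriever, Document

class FakeRetriever(BaseRetriever):
    """Hypothetical toy retriever that always returns one stub document."""

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        return [Document(page_content=f"Stub result for: {query}")]

# agent = RagAgent(llm_program, retriever=FakeRetriever())
# for message in agent.run([HumanMessage(content="What is LangChain?")]):
#     print(message)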
libs/langchain/langchain/automaton/agent_implementations/think_act_agent.py (new file, 230 lines)

@@ -0,0 +1,230 @@
"""Implementation of a think-act agent.

Uses priming messages with a text-based LLM.
"""
from __future__ import annotations

import ast
import re
from typing import Iterator, List, Optional, Sequence, Union

from langchain.automaton.runnables import (
    create_llm_program,
)
from langchain.automaton.tool_utils import generate_tool_info
from langchain.automaton.typedefs import (
    AdHocMessage,
    Agent,
    AgentFinish,
    FunctionCallRequest,
    FunctionCallResponse,
    MessageLike,
)
from langchain.prompts import SystemMessagePromptTemplate
from langchain.schema import (
    AIMessage,
    BaseMessage,
    HumanMessage,
    PromptValue,
    SystemMessage,
)
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools import BaseTool, Tool

TEMPLATE_ = """\
Respond to the human as helpfully and accurately as possible. \
You have access to the following tools:
{tools_description}
Use a blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).

Valid "action" values: "Final Answer" or {tool_names}

You can only use a single tool at a time.

<action>
{{
  "action": $TOOL_NAME,
  "action_input": $INPUT
}}
</action>

Follow this format:

Question: input question to answer
Thought: consider previous and subsequent steps
<action>
$BLOB
</action>

Observation: action result
... (repeat Thought/Action/Observation N times)
Thought: I know what to respond
<action>
{{
  "action": "Final Answer",
  "action_input": "Final response to human"
}}
</action>

Begin:

Reminder to ALWAYS respond with a valid blob of a single action. \
Use tools if necessary. Respond directly if appropriate. \
Format is <action>$BLOB</action> then Observation.\
"""


def generate_memory(tools: Sequence[Tool]) -> List[MessageLike]:
    """Set up basic memory for the agent."""
    tools_info = generate_tool_info(tools)
    return [SystemMessagePromptTemplate.from_template(TEMPLATE_).format(**tools_info)]


class ActionParser:
    """A utility class to encode and decode action blocks."""

    def __init__(self) -> None:
        """Initialize the ActionParser."""
        self.pattern = re.compile(r"<action>(?P<action_blob>.*?)<\/action>", re.DOTALL)

    def decode(self, text: Union[BaseMessage, str]) -> Optional[MessageLike]:
        """Decode the action."""
        if isinstance(text, BaseMessage):
            text = text.content
        match = self.pattern.search(text)
        if match:
            action_blob = match.group("action_blob")
            try:
                data = ast.literal_eval(action_blob)
            except SyntaxError:
                return AdHocMessage(
                    type="error", data=f"Invalid action blob {action_blob}"
                )
            name = data["action"]

            if name == "Final Answer":  # Special-cased "tool" for the final answer
                return AgentFinish(result=data["action_input"])
            action_input = data["action_input"]
            if isinstance(action_input, str) and not action_input:
                named_arguments = {}
            elif isinstance(action_input, dict):
                named_arguments = action_input
            else:
                return AdHocMessage(
                    type="error",
                    data=f"Invalid action blob {action_blob}, action_input must be a dict",
                )

            return FunctionCallRequest(
                name=data["action"], named_arguments=named_arguments or {}
            )
        else:
            return None


class ThinkActPromptGenerator(PromptValue):
    """Prompt value for the think-act paradigm.

    The prompt value can do additional priming of the LLM; this is used in the
    to_string method right now.
    """

    messages: Sequence[MessageLike]

    class Config:
        arbitrary_types_allowed = True

    def to_string(self) -> str:
        """The string variant of the prompt."""
        finalized = []
        messages = self.messages
        for idx, message in enumerate(messages):
            if isinstance(message, AdHocMessage):
                if message.type == "error":
                    finalized.extend(
                        [
                            f"Error: Malformed <action> blob with error: {message.data}. "
                            "Please re-write the action correctly."
                            "\n",
                        ]
                    )
                    continue
                if message.type != "prime":
                    raise AssertionError()
                component = message.data
                finalized.append(component)
                continue

            if isinstance(message, FunctionCallResponse):
                component = f"Observation: {message.result}"
            elif isinstance(message, HumanMessage):
                component = f"Question: {message.content.strip()}"
            elif isinstance(message, (AIMessage, SystemMessage)):
                component = message.content.strip()
            elif isinstance(message, FunctionCallRequest):
                # This is an internal message and should not be returned to the user.
                continue
            elif isinstance(message, AgentFinish):
                component = f"Answer: {message.result}"
            else:
                raise NotImplementedError()
            finalized.extend([component, "\n"])
        return "".join(finalized).strip()

    def to_messages(self) -> List[BaseMessage]:
        """Return the prompt as a list of messages."""
        messages = []
        for message in self.messages:
            if isinstance(message, BaseMessage):
                messages.append(message)
            elif isinstance(message, FunctionCallResponse):
                messages.append(
                    SystemMessage(content=f"Observation: `{message.result}`")
                )
        return messages

    @classmethod
    def from_message_log(cls, messages: Sequence[MessageLike]):
        """Create a ThinkActPromptGenerator from a message log, allowing positional arguments."""
        return cls(messages=messages)


class ThinkActAgent(Agent):
    def __init__(
        self,
        llm: BaseLanguageModel,
        tools: Sequence[BaseTool],
        *,
        max_iterations: int = 10,
    ) -> None:
        """Initialize the chat automaton."""
        self.think_act = create_llm_program(
            llm,
            prompt_generator=ThinkActPromptGenerator.from_message_log,
            stop=["Observation:", "observation:"],
            parser=ActionParser().decode,
            tools=tools,
            invoke_tools=True,
        )
        self.max_iterations = max_iterations

    def run(
        self,
        messages: Sequence[MessageLike],
        *,
        config: Optional[dict] = None,
        max_iterations: int = 100,
    ) -> Iterator[MessageLike]:
        """Run the agent."""
        all_messages = list(messages)
        for _ in range(max_iterations):
            if all_messages and isinstance(all_messages[-1], AgentFinish):
                break

            if all_messages and isinstance(
                all_messages[-1], (FunctionCallResponse, HumanMessage)
            ):
                all_messages.append(AdHocMessage(type="prime", data="Thought:"))

            new_messages = self.think_act.invoke(all_messages, config=config)
            yield from new_messages
            all_messages.extend(new_messages)
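
A quick, hedged example (not in the diff) of what ActionParser.decode produces for a completion containing an <action> blob; `completion` is a made-up model output.

# ast.literal_eval accepts the JSON-style blob as a Python dict literal.
completion = (
    "Thought: I need to add the numbers\n"
    '<action>\n{"action": "add", "action_input": {"x": 1, "y": 2}}\n</action>'
)
request = ActionParser().decode(completion)
# -> FunctionCallRequest(name="add", named_arguments={"x": 1, "y": 2})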
libs/langchain/langchain/automaton/agent_implementations/xml_agent.py (new file, 105 lines)

@@ -0,0 +1,105 @@
from __future__ import annotations

import ast
import re
from typing import List, Sequence, Union, Optional

from langchain.automaton.agent import SequentialAgent, WorkingMemoryProcessor
from langchain.automaton.prompt_generator import AdapterBasedGenerator
from langchain.automaton.runnables import create_llm_program
from langchain.automaton.tool_utils import generate_tool_info
from langchain.automaton.typedefs import (
    AgentFinish,
    FunctionCallRequest,
    FunctionCallResponse,
    MessageLike,
)
from langchain.prompts import SystemMessagePromptTemplate
from langchain.schema import BaseMessage, HumanMessage
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools import BaseTool

TEMPLATE_ = SystemMessagePromptTemplate.from_template(
    """Respond to the human as helpfully and accurately as \
possible. You have access to the following tools:
{tools_description}

Use a blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).

Valid "action" values: "Final Answer" or {tool_names}

Provide only ONE action per $BLOB, as shown.

<action>
{{
  "action": $TOOL_NAME,
  "action_input": $INPUT
}}
</action>

When invoking a tool do not provide any clarifying information.

The human will forward results of tool invocations as "Observations".

When you know the answer, paraphrase the information in the observations properly and respond to the user. \
If you do not know the answer, use more tools.

You can only take a single action at a time."""
)


def get_start_state(tools: Sequence[BaseTool]) -> List[BaseMessage]:
    """Generate a prompt for the agent."""
    tool_info = generate_tool_info(tools)
    msg = TEMPLATE_.format(**tool_info)
    return [msg]


def _decode(text: Union[BaseMessage, str]) -> MessageLike:
    """Decode the action."""
    pattern = re.compile(r"<action>(?P<action_blob>.*?)<\/action>", re.DOTALL)
    if not isinstance(text, BaseMessage):
        raise NotImplementedError()
    _text = text.content
    match = pattern.search(_text)
    if match:
        action_blob = match.group("action_blob")
        data = ast.literal_eval(action_blob)
        name = data["action"]
        if name == "Final Answer":  # Special-cased "tool" for the final answer
            return AgentFinish(result=data["action_input"])
        return FunctionCallRequest(
            name=data["action"], named_arguments=data["action_input"] or {}
        )
    else:
        return AgentFinish(result=text)


# PUBLIC API


def create_xml_agent(
    llm: BaseLanguageModel,
    tools: Sequence[BaseTool],
    memory_processor: Optional[WorkingMemoryProcessor] = None,
) -> SequentialAgent:
    """XML-based chat agent."""
    prompt_generator = AdapterBasedGenerator(
        msg_adapters={
            FunctionCallResponse: lambda msg: HumanMessage(
                content=f"Observation: {msg.result}"
            ),
        }
    )

    llm_program = create_llm_program(
        llm,
        prompt_generator=prompt_generator,
        tools=tools,
        parser=_decode,
    )
    return SequentialAgent(
        llm_program,
        memory_processor=memory_processor,
    )
libs/langchain/langchain/automaton/notebooks/chat_agent.ipynb (new file, 325 lines)
@@ -0,0 +1,325 @@

# General Chat Agent w/ XML like encoding

In [25]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.insert(0, '../../../')

Output:
The autoreload extension is already loaded. To reload it, use:
 %reload_ext autoreload

In [41]:
from typing import List, Optional, Any
from langchain.tools import tool, Tool
from langchain.schema.messages import AIMessage, HumanMessage
from langchain.chat_models.anthropic import ChatAnthropic
from langchain.automaton.agent_implementations import xml_agent

In [27]:
def get_tools() -> List[Tool]:
    @tool
    def name() -> str:
        """Use to look up the user's name"""
        return "Eugene"

    @tool
    def get_weather(city: str) -> str:
        """Get weather in a specific city."""
        return "42F and sunny"

    @tool
    def add(x: int, y: int) -> int:
        """Use to add two numbers. For example; { "x": 2, "y": 10}"""
        return x + y

    @tool
    def add_3(x: int, y: int, z: int) -> int:
        """Add 3 numbers together."""
        return x + y + z

    @tool
    def divide(x: float, y: float) -> float:
        """Divide x by y"""
        return x / y

    return list(locals().values())


tools = get_tools()

In [30]:
llm = ChatAnthropic(temperature=0)

In [31]:
start_messages = xml_agent.get_start_state(tools)

In [32]:
agent = xml_agent.create_xml_agent(llm, tools)

In [33]:
question = HumanMessage(
    content="Sum of first 10 numbers starting from 1? use one tool at a time"
)
messages = start_messages + [question]

In [39]:
for msg in agent.run(messages):
    print(msg)

Output:
ai: <action>
{
  "action": "add",
  "action_input": {
    "x": 1,
    "y": 2
  }
}
</action>
FunctionCall(name=add, named_arguments={'x': 1, 'y': 2})
FunctionResult(name=add, result=3, error=None)
ai: <action>
{
  "action": "add",
  "action_input": {
    "x": 3,
    "y": 3
  }
}
</action>
FunctionCall(name=add, named_arguments={'x': 3, 'y': 3})
FunctionResult(name=add, result=6, error=None)
ai: <action>
{
  "action": "add",
  "action_input": {
    "x": 6,
    "y": 4
  }
}
</action>
FunctionCall(name=add, named_arguments={'x': 6, 'y': 4})
FunctionResult(name=add, result=10, error=None)
ai: <action>
{
  "action": "add",
  "action_input": {
    "x": 10,
    "y": 5
  }
}
</action>
FunctionCall(name=add, named_arguments={'x': 10, 'y': 5})
FunctionResult(name=add, result=15, error=None)
ai: <action>
{
  "action": "add",
  "action_input": {
    "x": 15,
    "y": 6
  }
}
</action>
FunctionCall(name=add, named_arguments={'x': 15, 'y': 6})
FunctionResult(name=add, result=21, error=None)
ai: <action>
{
  "action": "add",
  "action_input": {
    "x": 21,
    "y": 7
  }
}
</action>
FunctionCall(name=add, named_arguments={'x': 21, 'y': 7})
FunctionResult(name=add, result=28, error=None)
ai: <action>
{
  "action": "add",
  "action_input": {
    "x": 28,
    "y": 8
  }
}
</action>
FunctionCall(name=add, named_arguments={'x': 28, 'y': 8})
FunctionResult(name=add, result=36, error=None)
ai: <action>
{
  "action": "add",
  "action_input": {
    "x": 36,
    "y": 9
  }
}
</action>
FunctionCall(name=add, named_arguments={'x': 36, 'y': 9})
FunctionResult(name=add, result=45, error=None)
ai: <action>
{
  "action": "add",
  "action_input": {
    "x": 45,
    "y": 10
  }
}
</action>
FunctionCall(name=add, named_arguments={'x': 45, 'y': 10})
FunctionResult(name=add, result=55, error=None)
ai: <action>
{
  "action": "Final Answer",
  "action_input": {
    "answer": "The sum of the first 10 numbers starting from 1 is 55."
  }
}
</action>
AgentFinish(result={'answer': 'The sum of the first 10 numbers starting from 1 is 55.'})

In [40]:
# # Visualize

# import html
# from langchain.automaton.chat_agent import prompt_generator

# messages = prompt_generator(message_log)

# # Generate HTML divs
# html_divs = []
# for idx, msg in enumerate(messages):
#     background_color = "DodgerBlue" if idx % 2 == 0 else "Tomato"
#     div_content = f"""
#     <div style="border: 1px solid black; padding: 10px; background-color:{background_color}">
#         <div style="display: flex;">
#             <div style="font-weight: bold; margin-right: 10px;">{html.escape(msg.type)}</div>
#             <div>{html.escape(msg.content)}</div>
#         </div>
#     </div>
#     """
#     html_divs.append(div_content)

# # Display HTML divs in a Jupyter Notebook cell
# from IPython.display import HTML, display

# html_output = "\n".join(html_divs)
# display(HTML(html_output))

Output:
ImportError                               Traceback (most recent call last)
Cell In[40], line 4
      1 # Visualize
      3 import html
----> 4 from langchain.automaton.chat_agent import prompt_generator
      6 messages = prompt_generator(message_log)
      8 # Generate HTML divs

ImportError: cannot import name 'prompt_generator' from 'langchain.automaton.chat_agent' (/home/eugene/src/langchain/libs/langchain/langchain/automaton/notebooks/../../../langchain/automaton/chat_agent.py)

Notebook metadata: Python 3 (ipykernel), Python 3.10.12, nbformat 4.
libs/langchain/langchain/automaton/notebooks/openai.ipynb (new file, 167 lines)
@@ -0,0 +1,167 @@

# Open AI Chat

In [ ]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.insert(0, '../../../')

In [6]:
from typing import List, Optional, Any
from langchain.tools import tool, Tool
from langchain.automaton.agent_implementations.openai_agent import create_openai_agent
from langchain.schema.messages import HumanMessage
from langchain.chat_models.openai import ChatOpenAI

In [7]:
def get_tools() -> List[Tool]:
    @tool
    def name() -> str:
        """Use to look up the user's name"""
        return "Eugene"

    @tool
    def get_weather(city: str) -> str:
        """Get weather in a specific city."""
        return "42F and sunny"

    @tool
    def add(x: int, y: int) -> int:
        """Use to add two numbers. For example; { "x": 2, "y": 10}"""
        return x + y

    @tool
    def add_3(x: int, y: int, z: int) -> int:
        """Add 3 numbers together."""
        return x + y + z

    @tool
    def divide(x: float, y: float) -> float:
        """Divide x by y"""
        return x / y

    return list(locals().values())

tools = get_tools()

In [14]:
llm = ChatOpenAI(temperature=0)
agent = create_openai_agent(llm, tools)

In [15]:
a = agent.run([HumanMessage(content="Sum of first 10 numbers starting from 1? use one tool at a time")])

In [16]:
for msg in a:
    print(msg)

Output:
ai: {'function_call': {'name': 'add', 'arguments': '{\n  "x": 1,\n  "y": 2\n}'}}
FunctionCall(name=add, named_arguments={'x': 1, 'y': 2})
FunctionResult(name=add, result=3, error=None)
ai: {'function_call': {'name': 'add', 'arguments': '{\n  "x": 3,\n  "y": 3\n}'}}
FunctionCall(name=add, named_arguments={'x': 3, 'y': 3})
FunctionResult(name=add, result=6, error=None)
ai: {'function_call': {'name': 'add', 'arguments': '{\n  "x": 6,\n  "y": 4\n}'}}
FunctionCall(name=add, named_arguments={'x': 6, 'y': 4})
FunctionResult(name=add, result=10, error=None)
ai: {'function_call': {'name': 'add', 'arguments': '{\n  "x": 10,\n  "y": 5\n}'}}
FunctionCall(name=add, named_arguments={'x': 10, 'y': 5})
FunctionResult(name=add, result=15, error=None)
ai: {'function_call': {'name': 'add', 'arguments': '{\n  "x": 15,\n  "y": 6\n}'}}
FunctionCall(name=add, named_arguments={'x': 15, 'y': 6})
FunctionResult(name=add, result=21, error=None)
ai: {'function_call': {'name': 'add', 'arguments': '{\n  "x": 21,\n  "y": 7\n}'}}
FunctionCall(name=add, named_arguments={'x': 21, 'y': 7})
FunctionResult(name=add, result=28, error=None)
ai: {'function_call': {'name': 'add', 'arguments': '{\n  "x": 28,\n  "y": 8\n}'}}
FunctionCall(name=add, named_arguments={'x': 28, 'y': 8})
FunctionResult(name=add, result=36, error=None)
ai: {'function_call': {'name': 'add', 'arguments': '{\n  "x": 36,\n  "y": 9\n}'}}
FunctionCall(name=add, named_arguments={'x': 36, 'y': 9})
FunctionResult(name=add, result=45, error=None)
ai: {'function_call': {'name': 'add', 'arguments': '{\n  "x": 45,\n  "y": 10\n}'}}
FunctionCall(name=add, named_arguments={'x': 45, 'y': 10})
FunctionResult(name=add, result=55, error=None)
ai: The sum of the first 10 numbers starting from 1 is 55.
AgentFinish(result=ai: The sum of the first 10 numbers starting from 1 is 55.)

Notebook metadata: Python 3 (ipykernel), Python 3.10.12, nbformat 4.
@@ -0,0 +1,222 @@

# Think Act Agent

In [1]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.insert(0, '../../../')

In [2]:
from typing import List, Optional, Any
from langchain.tools import tool, Tool
from langchain.schema.messages import AIMessage, HumanMessage
from langchain.chat_models.anthropic import ChatAnthropic
from langchain.llms import OpenAI
from langchain.automaton.agent_implementations import think_act_agent

In [3]:
def get_tools() -> List[Tool]:
    @tool
    def name() -> str:
        """Use to look up the user's name"""
        return "Eugene"

    @tool
    def get_weather(city: str) -> str:
        """Get weather in a specific city."""
        return "42F and sunny"

    @tool
    def add(x: int, y: int) -> int:
        """Use to add two numbers. For example; { "x": 2, "y": 10}"""
        return x + y

    @tool
    def divide(x: float, y: float) -> float:
        """Divide x by y"""
        return x / y

    return list(locals().values())


tools = get_tools()

In [4]:
llm = OpenAI(temperature=0)

In [5]:
agent = think_act_agent.ThinkActAgent(llm, tools)

In [6]:
start_messages = think_act_agent.generate_memory(tools)

In [9]:
from langchain.callbacks.manager import (
    trace_as_chain_group,
    atrace_as_chain_group,
)

In [12]:
question = HumanMessage(content="what is 12381237/23? What is my name? And what is the weather in new york?")
messages = start_messages + [question]

with trace_as_chain_group("think_act_agent") as group_manager:
    for msg in agent.run(messages, config={"callbacks": group_manager}):
        print(msg)

Output:
ai: I need to use the divide tool and the get_weather tool
<action>
{
  "action": "divide",
  "action_input": { "x": 12381237, "y": 23 }
}
</action>

FunctionCall(name=divide, named_arguments={'x': 12381237, 'y': 23})
FunctionResult(name=divide, result=538314.6521739131, error=None)
ai: I need to use the name tool and the get_weather tool
<action>
{
  "action": "name",
  "action_input": ""
}
</action>

FunctionCall(name=name, named_arguments={})
FunctionResult(name=name, result=Eugene, error=None)
ai: I need to use the get_weather tool
<action>
{
  "action": "get_weather",
  "action_input": "New York"
}
</action>

AdHocMessage(type=error, data=Invalid action blob 
{
  "action": "get_weather",
  "action_input": "New York"
}
, action_input must be a dict)
ai: 
Thought:I need to use the get_weather tool
<action>
{
  "action": "get_weather",
  "action_input": {"city": "New York"}
}
</action>

FunctionCall(name=get_weather, named_arguments={'city': 'New York'})
FunctionResult(name=get_weather, result=42F and sunny, error=None)
ai: I know what to respond
<action>
{
  "action": "Final Answer",
  "action_input": "Hi Eugene, the answer to 12381237/23 is 538314.6521739131. The weather in New York is 42F and sunny."
}
</action>
AgentFinish(result=Hi Eugene, the answer to 12381237/23 is 538314.6521739131. The weather in New York is 42F and sunny.)

In [ ]:

Notebook metadata: Python 3 (ipykernel), Python 3.10.12, nbformat 4.
libs/langchain/langchain/automaton/processors.py (new file, 62 lines)
@@ -0,0 +1,62 @@
"""Code for processing the working memory."""
import abc
from typing import List, Sequence, Callable

from langchain.schema.messages import get_buffer_string, SystemMessage

from langchain.automaton.typedefs import MessageLike
from langchain.automaton.prompt_generator import PromptGenerator


class WorkingMemoryManager(abc.ABC):
    @abc.abstractmethod
    def process(self, messages: Sequence[MessageLike]) -> List[MessageLike]:
        """Process the working memory, returning a potentially new working memory."""


class TokenBufferProcessor(WorkingMemoryManager):
    """Trim the working memory to a maximum context length (aka token limit)."""

    def __init__(
        self,
        max_token_limit: int,
        token_counter: Callable[[str], int],
        prompt_generator: PromptGenerator,
        skip_system_messages: bool = True,
    ) -> None:
        """Initialize with a token limit, token counter, and prompt generator."""
        self.token_counter = token_counter
        self.max_token_limit = max_token_limit
        self.prompt_generator = prompt_generator
        self.skip_system_messages = skip_system_messages

    def process(self, messages: Sequence[MessageLike]) -> List[MessageLike]:
        """Update the working memory with the given messages."""
        # Work forwards from the start of the buffer, dropping messages until
        # the buffer fits within the token limit.
        messages = list(messages)
        curr_buffer_length = self.count_across_messages(messages)

        idx = 0

        while curr_buffer_length > self.max_token_limit:
            if idx >= len(messages):
                raise AssertionError("No messages left in buffer")

            if self.skip_system_messages and isinstance(messages[idx], SystemMessage):
                idx += 1
                continue

            messages.pop(idx)  # Drop the earliest non-skipped message
            if not messages:
                raise AssertionError("No messages left in buffer")
            curr_buffer_length = self.count_across_messages(messages)
        return messages

    def count_across_messages(self, messages: Sequence[MessageLike]) -> int:
        """Count the number of tokens across messages."""
        if isinstance(self.prompt_generator, PromptGenerator):
            base_messages = self.prompt_generator.to_messages(messages)
        else:
            base_messages = self.prompt_generator(messages)
        buffer_string = get_buffer_string(base_messages)
        return self.token_counter(buffer_string)
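
A minimal sketch (not part of the diff) of TokenBufferProcessor with a naive whitespace token counter; since count_across_messages also accepts a plain callable, the prompt_generator function from chat_agent.py above can stand in for a full PromptGenerator.

from langchain.automaton.chat_agent import prompt_generator

# Crude stand-in for a real tokenizer; counts whitespace-separated words.
processor = TokenBufferProcessor(
    max_token_limit=2000,
    token_counter=lambda text: len(text.split()),
    prompt_generator=prompt_generator,
)
# trimmed = processor.process(messages)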
libs/langchain/langchain/automaton/prompt_generator.py (new file, 92 lines)
@@ -0,0 +1,92 @@
"""Prompt generation for the automaton."""
from __future__ import annotations

import abc
from typing import Mapping, Any, Callable, List, Sequence, Optional, Union

from langchain.automaton.typedefs import MessageLike
from langchain.schema import BaseMessage, PromptValue


class BoundPromptValue(PromptValue):
    """A prompt value that is bound to a specific value."""

    as_string: Callable[[], str]
    as_messages: Callable[[], List[BaseMessage]]

    def to_string(self) -> str:
        """Return prompt value as string."""
        return self.as_string()

    def to_messages(self) -> List[BaseMessage]:
        """Return prompt as a list of Messages."""
        return self.as_messages()


class PromptGenerator(abc.ABC):
    @abc.abstractmethod
    def to_messages(
        self, original_messages: Sequence[MessageLike]
    ) -> List[BaseMessage]:
        """Generate a list of messages from message-like objects."""

    @abc.abstractmethod
    def to_string(self, original_messages: Sequence[MessageLike]) -> str:
        """Generate a string prompt from message-like objects."""

    def to_prompt_value(self, original_messages: Sequence[MessageLike]) -> PromptValue:
        """Generate a prompt value from message-like objects."""
        return BoundPromptValue(
            as_string=lambda: self.to_string(original_messages),
            as_messages=lambda: self.to_messages(original_messages),
        )


class AdapterBasedGenerator(PromptGenerator):
    def __init__(
        self,
        *,
        msg_adapters: Optional[
            Mapping[Any, Callable[[MessageLike], Union[BaseMessage, List[BaseMessage]]]]
        ] = None,
        str_adapters: Optional[Mapping[Any, Callable[[MessageLike], str]]] = None,
    ) -> None:
        """Initialize the adapter-based generator."""
        self.msg_adapters = msg_adapters or {}
        self.str_adapters = str_adapters or {}

    def to_messages(self, messages: Sequence[MessageLike]) -> List[BaseMessage]:
        """Generate a prompt from message-like objects."""
        new_messages = []

        for original_message in messages:
            adapter = self.msg_adapters.get(type(original_message), None)
            if adapter:
                translated = adapter(original_message)
                if isinstance(translated, BaseMessage):
                    new_messages.append(translated)
                else:
                    new_messages.extend(translated)
                continue

            if isinstance(original_message, BaseMessage):
                # Only adds BaseMessages by default;
                # internal messages are ignored.
                new_messages.append(original_message)

        return new_messages

    def to_string(self, messages: Sequence[MessageLike]) -> str:
        """Generate a prompt from message-like objects."""
        string_prompts = []
        for original_message in messages:
            adapter = self.str_adapters.get(type(original_message), None)
            if adapter:
                # append, not extend: extending with a str would add characters
                string_prompts.append(adapter(original_message))
            else:
                raise RuntimeError(
                    f"String adapter not found for {type(original_message)}"
                )

        return "\n".join(string_prompts)
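
A short sketch (not from the diff) mirroring how xml_agent uses AdapterBasedGenerator: one adapter turns tool results into human-readable observations, and messages without an adapter fall through to the BaseMessage default.

from langchain.automaton.typedefs import FunctionCallResponse
from langchain.schema import HumanMessage

generator = AdapterBasedGenerator(
    msg_adapters={
        FunctionCallResponse: lambda msg: HumanMessage(
            content=f"Observation: {msg.result}"
        ),
    }
)
# generator.to_messages(message_log) keeps BaseMessages, adapts tool results,
# and silently drops internal messages that have no adapter.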
libs/langchain/langchain/automaton/runnables.py (new file, 262 lines)
@@ -0,0 +1,262 @@
|
||||
"""Module contains useful runnables for agents."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, TypeVar, Union
|
||||
|
||||
from langchain.schema.retriever import BaseRetriever
|
||||
from langchain.automaton.typedefs import (
|
||||
FunctionCallRequest,
|
||||
FunctionCallResponse,
|
||||
MessageLike,
|
||||
RetrievalResponse,
|
||||
RetrievalRequest,
|
||||
)
|
||||
from langchain.automaton.prompt_generator import PromptGenerator
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.schema import AIMessage, BaseMessage, PromptValue, Document
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.schema.output_parser import BaseOutputParser
|
||||
from langchain.schema.runnable import (
|
||||
Runnable,
|
||||
RunnableConfig,
|
||||
RunnableLambda,
|
||||
RunnableMap,
|
||||
RunnablePassthrough,
|
||||
patch_config,
|
||||
)
|
||||
from langchain.tools import BaseTool
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def _to_message(result: Union[BaseMessage, str]) -> BaseMessage:
|
||||
"""Convert to a list of messages."""
|
||||
if isinstance(result, BaseMessage):
|
||||
return result
|
||||
elif isinstance(result, str):
|
||||
return AIMessage(content=result)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported type {type(result)}")
|
||||
|
||||
|
||||
def _to_list(element: Union[None, T, List[T]]) -> List[T]:
|
||||
"""Convert to a sequence."""
|
||||
if element is None:
|
||||
return []
|
||||
elif isinstance(element, list):
|
||||
return element
|
||||
else:
|
||||
return [element]
|
||||
|
||||
|
||||
def _to_runnable_parser(parser: Optional[BaseOutputParser]) -> Runnable:
|
||||
"""Adapt a parser to a runnable."""
|
||||
if parser is None:
|
||||
# Then create a runnable that returns no messages
|
||||
return RunnableLambda(lambda *args, **kwargs: None)
|
||||
elif isinstance(parser, Runnable):
|
||||
return parser
|
||||
elif isinstance(parser, Callable):
|
||||
return RunnableLambda(parser)
|
||||
else:
|
||||
raise ValueError(f"Expected BaseOutputParser, got {parser}")
|
||||
|
||||
|
||||
def _concatenate_head_and_tail(intermediate_input: Dict[str, Any]) -> List[BaseMessage]:
|
||||
"""Concatenate head and tail into a single list."""
|
||||
head = _to_list(intermediate_input["head"])
|
||||
tail = _to_list(intermediate_input["tail"])
|
||||
return head + tail
|
||||
|
||||
|
||||
def _apply_and_concat(
|
||||
head: Union[Runnable, Callable], tail: Union[Runnable, Callable]
|
||||
) -> Runnable:
|
||||
"""Apply head and tail and concatenate the results.
|
||||
|
||||
Note: Probably generalize to _apply(funcs) and _concatenate runnables
|
||||
|
||||
Args:
|
||||
head: A runnable or callable
|
||||
tail: A runnable or callable
|
||||
|
||||
Returns:
|
||||
A runnable that applies head and tail and concatenates the results in order.
|
||||
"""
|
||||
head_ = head if isinstance(head, Runnable) else RunnableLambda(head)
|
||||
tail_ = tail if isinstance(tail, Runnable) else RunnableLambda(tail)
|
||||
|
||||
return (
|
||||
RunnableMap(
|
||||
steps={
|
||||
"head": head_,
|
||||
"tail": tail_,
|
||||
}
|
||||
)
|
||||
| _concatenate_head_and_tail
|
||||
)
|
||||
|
||||
|
||||
def _to_retriever_input(message: MessageLike) -> str:
|
||||
"""Convert a message to a retriever input."""
|
||||
if isinstance(message, str):
|
||||
return message
|
||||
elif isinstance(message, BaseMessage):
|
||||
return message.content
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported type {type(message)}")


# PUBLIC API


def create_tool_invoker(
    tools: Sequence[BaseTool],
) -> Runnable[MessageLike, Optional[FunctionCallResponse]]:
    """See if possible to re-write with router

    TODO:
    * re-write with router
    * potentially remove hack replace MessageLike with FunctionCall, requires
      a branching runnable
    """
    tools_by_name = {tool.name: tool for tool in tools}

    def func(
        function_call: MessageLike,
        run_manager: CallbackManagerForChainRun,
        config: RunnableConfig,
    ) -> Optional[FunctionCallResponse]:
        """A function that can invoke a tool using .invoke."""
        if not isinstance(
            function_call, FunctionCallRequest
        ):  # TODO(Hack): Workaround lack of conditional apply
            return None
        try:
            tool = tools_by_name[function_call.name]
        except KeyError:
            raise AssertionError(f"No such tool: {function_call.name}")
        try:
            result = tool.invoke(
                function_call.named_arguments or {},
                patch_config(config, callbacks=run_manager.get_child()),
            )
            error = None
        except Exception as e:
            result = None
            error = repr(e) + repr(function_call.named_arguments)

        return FunctionCallResponse(
            name=function_call.name, result=result, error=error
        )

    async def afunc(
        function_call: MessageLike,
        run_manager: CallbackManagerForChainRun,
        config: RunnableConfig,
    ) -> Optional[FunctionCallResponse]:
        """A function that can invoke a tool using .ainvoke."""
        if not isinstance(
            function_call, FunctionCallRequest
        ):  # TODO(Hack): Workaround lack of conditional apply
            return None
        try:
            tool = tools_by_name[function_call.name]
        except KeyError:
            raise AssertionError(f"No such tool: {function_call.name}")
        try:
            result = await tool.ainvoke(
                function_call.named_arguments or {},
                patch_config(config, callbacks=run_manager.get_child()),
            )
            error = None
        except Exception as e:
            result = None
            error = repr(e) + repr(function_call.named_arguments)

        return FunctionCallResponse(
            name=function_call.name, result=result, error=error
        )

    return RunnableLambda(func=func, afunc=afunc)
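
For orientation, a hedged sketch of the invoker's behavior as implemented above; `get_time` is a hypothetical example tool, and this assumes `RunnableLambda` forwards `run_manager` and `config` to functions that accept them.

# Sketch only: assumes create_tool_invoker and FunctionCallRequest from this package.
from langchain.schema.messages import AIMessage
from langchain.tools import tool


@tool
def get_time() -> str:
    """Get the current time (hypothetical example tool)."""
    return "9 PM"


invoker = create_tool_invoker([get_time])
# Anything that is not a FunctionCallRequest passes through as None
# (the conditional-apply workaround noted in the TODO above).
assert invoker.invoke(AIMessage(content="hi")) is None
# A matching request invokes the tool and wraps the output in a FunctionCallResponse.
response = invoker.invoke(FunctionCallRequest(name="get_time", named_arguments={}))
assert response.result == "9 PM" and response.error is None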


def create_llm_program(
    llm: BaseLanguageModel,
    prompt_generator: Union[
        Callable[
            [Sequence[MessageLike]], Union[str, PromptValue, Sequence[BaseMessage]]
        ],
        PromptGenerator,
        Runnable,
    ],
    *,
    tools: Optional[Sequence[BaseTool]] = None,
    stop: Optional[Sequence[str]] = None,
    parser: Union[
        Runnable[Union[BaseMessage, str], MessageLike],
        Callable[[Union[BaseMessage, str]], MessageLike],
        BaseOutputParser,
        None,
    ] = None,
    invoke_tools: bool = True,  # TODO(Eugene): Perhaps remove.
) -> Runnable[Sequence[MessageLike], List[MessageLike]]:
    """Create a runnable that provides a generalized interface to an LLM with actions.

    Args:
        llm: A language model
        prompt_generator: A function that takes a list of messages and returns a prompt
        tools: A list of tools to invoke
        stop: Optional list of stop tokens
        parser: Optional parser to apply to the output of the LLM
        invoke_tools: Whether to invoke tools on the output of the LLM

    Returns:
        A runnable that returns a list of new messages
    """
    if isinstance(prompt_generator, PromptGenerator):
        _prompt_generator = RunnableLambda(prompt_generator.to_prompt_value)
    elif isinstance(prompt_generator, Runnable):
        _prompt_generator = prompt_generator
    else:  # Otherwise attempt to convert to a runnable lambda
        _prompt_generator = RunnableLambda(prompt_generator)

    if stop:
        llm = llm.bind(stop=stop)

    # Add the parser to the end of the chain and concatenate the original llm
    # output with the parser output.
    # The parser runnable is always created, even when parser is None, so that
    # the _to_message adapter is always applied (regardless of the parser).
    _parser = _to_runnable_parser(parser)

    chain = _prompt_generator | llm | _apply_and_concat(_to_message, _parser)

    # Add the tool invoker to the end of the chain.
    if invoke_tools and tools:
        tool_invoker = create_tool_invoker(tools)
        invoke_on_last = RunnableLambda(lambda msgs: msgs[-1]) | tool_invoker
        complete_chain = chain | _apply_and_concat(
            RunnablePassthrough(), invoke_on_last
        )
    else:
        complete_chain = chain

    return complete_chain
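
To make the wiring concrete, a minimal sketch using the FakeChatModel test helper defined later in this diff (its availability here is an assumption of the sketch).

# Sketch only: FakeChatModel comes from langchain.automaton.tests.utils below.
from langchain.schema.messages import AIMessage, HumanMessage

llm = FakeChatModel(message_iter=iter([AIMessage(content="Hello")]))
program = create_llm_program(llm, prompt_generator=lambda messages: messages)
# The program returns only the new messages produced in this step.
assert program.invoke([HumanMessage(content="Hi")]) == [AIMessage(content="Hello")]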


def create_retriever(
    base_retriever: BaseRetriever,
) -> Runnable[RetrievalRequest, RetrievalResponse]:
    """Create a runnable retriever that uses messages."""

    def _from_retrieval_request(request: RetrievalRequest) -> str:
        """Extract the query string from a retrieval request."""
        return request.query

    def _to_retrieval_result(docs: List[Document]) -> RetrievalResponse:
        """Wrap a list of documents in a retrieval response."""
        return RetrievalResponse(results=docs)

    return (
        RunnableLambda(_from_retrieval_request)
        | base_retriever
        | RunnableLambda(_to_retrieval_result)
    )
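
A hedged sketch of the adapter with a toy retriever; `OneDocRetriever` is hypothetical and exists only to exercise the RetrievalRequest -> RetrievalResponse round trip.

# Sketch only: a minimal BaseRetriever subclass to exercise the adapter.
from typing import List

from langchain.schema import BaseRetriever, Document


class OneDocRetriever(BaseRetriever):
    """Hypothetical retriever that always returns a single document."""

    def _get_relevant_documents(self, query: str, *, run_manager) -> List[Document]:
        return [Document(page_content=f"docs about {query}")]


retriever = create_retriever(OneDocRetriever())
response = retriever.invoke(RetrievalRequest(query="agents"))
assert response.results[0].page_content == "docs about agents"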
42
libs/langchain/langchain/automaton/tests/test_mrkl_agent.py
Normal file
@@ -0,0 +1,42 @@
from __future__ import annotations

import langchain.automaton.agent_implementations.xml_agent
from langchain.agents.structured_chat.output_parser import StructuredChatOutputParser
from langchain.automaton.mrkl_agent import ActionParser


def test_structured_output_chat() -> None:
    parser = StructuredChatOutputParser()
    output = parser.parse(
        """
        ```json
        {
            "action": "hello",
            "action_input": {
                "a": 2
            }
        }
        ```
        """
    )
    # The parser returns an AgentAction for a well-formed action blob.
    assert output.tool == "hello"
    assert output.tool_input == {"a": 2}


def test_parser() -> None:
    """Test the parser."""
    sample_text = """
    Some text before
    <action>
    {
        "key": "value",
        "number": 42
    }
    </action>
    Some text after
    """
    action_parser = ActionParser()
    action = langchain.automaton.agent_implementations.xml_agent.decode(sample_text)
    assert action == {
        "key": "value",
        "number": 42,
    }
@@ -0,0 +1,96 @@
from __future__ import annotations

from typing import List, cast

import pytest

from langchain.automaton.agent_implementations.openai_agent import OpenAIAgent
from langchain.automaton.tests.utils import (
    FakeChatModel,
    construct_func_invocation_message,
)
from langchain.automaton.typedefs import (
    AgentFinish,
    FunctionCallRequest,
    FunctionCallResponse,
    MessageLog,  # NOTE: assumed to be provided by typedefs on this branch
)
from langchain.schema.messages import (
    AIMessage,
    SystemMessage,
)
from langchain.tools import Tool, tool
from langchain.tools.base import BaseTool


@pytest.fixture()
def tools() -> List[BaseTool]:
    @tool
    def get_time() -> str:
        """Get time."""
        return "9 PM"

    @tool
    def get_location() -> str:
        """Get location."""
        return "the park"

    return cast(List[Tool], [get_time, get_location])


def test_openai_agent(tools: List[Tool]) -> None:
    get_time, get_location = tools
    llm = FakeChatModel(
        message_iter=iter(
            [
                construct_func_invocation_message(get_time, {}),
                AIMessage(
                    content="The time is 9 PM.",
                ),
            ]
        )
    )

    agent = OpenAIAgent(llm=llm, tools=tools, max_iterations=10)

    message_log = MessageLog(
        [
            SystemMessage(
                content="What time is it?",
            )
        ]
    )

    expected_messages = [
        SystemMessage(
            content="What time is it?",
        ),
        AIMessage(
            content="",
            additional_kwargs={
                "function_call": {
                    "name": "get_time",
                    "arguments": "{}",
                }
            },
        ),
        FunctionCallRequest(
            name="get_time",
            named_arguments={},
        ),
        FunctionCallResponse(
            name="get_time",
            result="9 PM",
            error=None,
        ),
        AIMessage(
            content="The time is 9 PM.",
        ),
        AgentFinish(
            result=AIMessage(
                content="The time is 9 PM.",
            ),
        ),
    ]

    agent.run(message_log)
    assert message_log.messages == expected_messages
153
libs/langchain/langchain/automaton/tests/test_parsers.py
Normal file
@@ -0,0 +1,153 @@
"""Test parsers"""
|
||||
from typing import Any, List, Optional, cast
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.automaton.runnables import (
|
||||
RunnablePassthrough,
|
||||
_apply_and_concat,
|
||||
_to_list,
|
||||
_to_runnable_parser,
|
||||
create_llm_program,
|
||||
)
|
||||
from langchain.automaton.tests.utils import (
|
||||
FakeChatModel,
|
||||
)
|
||||
from langchain.automaton.typedefs import FunctionCallRequest, FunctionCallResponse, MessageLike
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage
|
||||
from langchain.schema.runnable import RunnableLambda
|
||||
from langchain.tools import BaseTool, tool
|
||||
|
||||
|
||||
def test_apply_and_concat() -> None:
|
||||
"""Test apply and concat."""
|
||||
msg = HumanMessage(content="Hello")
|
||||
llm = RunnableLambda(lambda *args, **kwargs: msg)
|
||||
|
||||
# Test that it works with a runnable
|
||||
chain = llm | _apply_and_concat(RunnablePassthrough(), RunnablePassthrough())
|
||||
assert chain.invoke({}) == [msg, msg]
|
||||
|
||||
chain = llm | _apply_and_concat(
|
||||
lambda msg: msg.content[0], lambda msg: msg.content[1]
|
||||
)
|
||||
assert chain.invoke({}) == ["H", "e"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_value, expected_output",
|
||||
[
|
||||
(None, []), # Test when input is None
|
||||
([1, 2, 3], [1, 2, 3]), # Test when input is a list of integers
|
||||
(5, [5]), # Test when input is a single integer
|
||||
(["a", "b", "c"], ["a", "b", "c"]), # Test when input is a list of strings
|
||||
("xyz", ["xyz"]), # Test when input is a single string
|
||||
([], []), # Test when input is an empty list
|
||||
],
|
||||
)
|
||||
def test_to_list(input_value: Any, expected_output: List) -> None:
|
||||
assert _to_list(input_value) == expected_output
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"parser, expected_output",
|
||||
[
|
||||
(None, None),
|
||||
(RunnablePassthrough(), AIMessage(content="Hello")),
|
||||
],
|
||||
)
|
||||
def to_runnable_parser(parser: Any, expected_output: Optional[BaseMessage]) -> None:
|
||||
"""To runnable parser."""
|
||||
parser_ = _to_runnable_parser(parser)
|
||||
assert parser_.invoke(AIMessage(content="Hello")) == [expected_output]
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def tools() -> List[BaseTool]:
|
||||
"""Make a tools fixture."""
|
||||
|
||||
@tool
|
||||
def get_time() -> str:
|
||||
"""Get time."""
|
||||
return "9 PM"
|
||||
|
||||
@tool
|
||||
def get_location() -> str:
|
||||
"""Get location."""
|
||||
return "the park"
|
||||
|
||||
return cast(List[BaseTool], [get_time, get_location])
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def fake_llm() -> BaseLanguageModel:
|
||||
"""Make a fake chat model."""
|
||||
llm = FakeChatModel(
|
||||
message_iter=iter(
|
||||
[
|
||||
AIMessage(
|
||||
content="Hello",
|
||||
),
|
||||
]
|
||||
)
|
||||
)
|
||||
return llm
|
||||
|
||||
|
||||
def test_simple_llm_program(fake_llm: BaseLanguageModel) -> None:
|
||||
"""Test simple llm program with no parser or tools."""
|
||||
get_time, _ = tools
|
||||
program = create_llm_program(
|
||||
fake_llm,
|
||||
prompt_generator=lambda x: x,
|
||||
)
|
||||
assert program.invoke("What time is it?") == [AIMessage(content="Hello")]
|
||||
|
||||
|
||||
def test_llm_program_with_parser(fake_llm: BaseLanguageModel) -> None:
|
||||
"""Test simple llm program with no parser or tools."""
|
||||
parser = RunnableLambda(lambda msg: AIMessage(content=msg.content + " parsed"))
|
||||
program = create_llm_program(
|
||||
fake_llm,
|
||||
prompt_generator=lambda x: x,
|
||||
parser=parser,
|
||||
)
|
||||
assert program.invoke("What time is it?") == [
|
||||
AIMessage(content="Hello"),
|
||||
AIMessage(content="Hello parsed"),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"parser, output",
|
||||
[
|
||||
(
|
||||
RunnableLambda(lambda msg: AIMessage(content="Goodbye")),
|
||||
[AIMessage(content="Hello"), AIMessage(content="Goodbye")],
|
||||
),
|
||||
(
|
||||
RunnableLambda(lambda msg: FunctionCallRequest(name="get_time")),
|
||||
[
|
||||
AIMessage(content="Hello"),
|
||||
FunctionCallRequest(name="get_time"),
|
||||
FunctionCallResponse(result="9 PM", name="get_time"),
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_llm_program_with_parser_and_tools(
|
||||
tools: List[BaseTool],
|
||||
fake_llm: BaseLanguageModel,
|
||||
parser: Any,
|
||||
output: List[MessageLike],
|
||||
) -> None:
|
||||
"""Test simple llm program with no parser or tools."""
|
||||
program = create_llm_program(
|
||||
fake_llm,
|
||||
prompt_generator=lambda x: x,
|
||||
parser=parser,
|
||||
tools=tools,
|
||||
invoke_tools=True,
|
||||
)
|
||||
assert program.invoke("What time is it?") == output
|
||||
@@ -0,0 +1,42 @@
from langchain.automaton.prompt_generation import AdapterBasedTranslator
from langchain.automaton.typedefs import FunctionCallResponse
from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage


def test_adapter_based_generator() -> None:
    translator = AdapterBasedTranslator()

    assert translator.to_messages(
        [
            SystemMessage(content="System"),
            AIMessage(content="Hi"),
            HumanMessage(content="Hello"),
        ]
    ) == [
        SystemMessage(content="System"),
        AIMessage(content="Hi"),
        HumanMessage(content="Hello"),
    ]

    translator = AdapterBasedTranslator(
        msg_adapters={
            HumanMessage: lambda m: AIMessage(content=m.content),
            FunctionCallResponse: lambda m: HumanMessage(
                content=f"Observation: {m.result}"
            ),
        }
    )

    assert translator.to_messages(
        [
            SystemMessage(content="System"),
            AIMessage(content="Hi"),
            HumanMessage(content="Hello"),
            FunctionCallResponse(name="func", result="result"),
        ]
    ) == [
        SystemMessage(content="System"),
        AIMessage(content="Hi"),
        AIMessage(content="Hello"),
        HumanMessage(content="Observation: result"),
    ]
20
libs/langchain/langchain/automaton/tests/test_routers.py
Normal file
@@ -0,0 +1,20 @@
from typing import Optional, Sequence

from langchain.automaton.typedefs import MessageLike
from langchain.schema import AIMessage, HumanMessage
from langchain.schema.runnable import Runnable, RunnableLambda


def test_router() -> None:
    messages = [AIMessage(content="Hello, world!")]

    program_state = RunnableLambda(lambda x: x)

    def route(messages: Sequence[MessageLike]) -> Optional[Runnable]:
        if isinstance(messages[-1], HumanMessage):
            return None
        else:
            return program_state

    router = RunnableLambda(route)
    # A runnable returned from the lambda is itself invoked on the same input,
    # so the identity program passes the messages through unchanged.
    assert router.invoke(messages) == messages
46
libs/langchain/langchain/automaton/tests/utils.py
Normal file
@@ -0,0 +1,46 @@
from __future__ import annotations

import json
from typing import Any, Iterator, List, Mapping

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models.base import BaseChatModel
from langchain.schema import AIMessage, BaseMessage, ChatGeneration, ChatResult
from langchain.tools import BaseTool


class FakeChatModel(BaseChatModel):
    """A fake chat model that returns a pre-defined response."""

    message_iter: Iterator[BaseMessage]

    @property
    def _llm_type(self) -> str:
        """The type of the model."""
        return "fake-openai-chat-model"

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: List[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generate a response to the given messages."""
        message = next(self.message_iter)
        return ChatResult(generations=[ChatGeneration(message=message)])


def construct_func_invocation_message(
    tool: BaseTool, args: Mapping[str, Any]
) -> AIMessage:
    """Construct a function invocation message."""
    return AIMessage(
        content="",
        additional_kwargs={
            "function_call": {
                "name": tool.name,
                "arguments": json.dumps(args),
            }
        },
    )
25
libs/langchain/langchain/automaton/tool_utils.py
Normal file
@@ -0,0 +1,25 @@
from typing import Sequence, TypedDict

from langchain.tools import BaseTool


def _generate_tools_descriptions(tools: Sequence[BaseTool]) -> str:
    """Generate a description of the tools."""
    return "\n".join([f"{tool_.name}: {tool_.description}" for tool_ in tools]) + "\n"


class ToolInfo(TypedDict):
    """A dictionary with formatted information about the available tools."""

    tool_names: str
    tools_description: str


def generate_tool_info(tools: Sequence[BaseTool]) -> ToolInfo:
    """Generate a string containing the names of the tools and their descriptions."""
    tools_description = _generate_tools_descriptions(tools)
    tool_names = ", ".join([tool_.name for tool_ in tools])
    return {
        "tool_names": tool_names,
        "tools_description": tools_description,
    }
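
For illustration, how the generated tool info might slot into a prompt template; `search` is a hypothetical example tool.

# Sketch only: formatting tool metadata for a prompt template.
from langchain.tools import tool


@tool
def search(query: str) -> str:
    """Search the web (hypothetical example tool)."""
    return "results"


info = generate_tool_info([search])
assert info["tool_names"] == "search"
# The description block is one "name: description" line per tool.
assert info["tools_description"].startswith("search: ")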
110
libs/langchain/langchain/automaton/typedefs.py
Normal file
@@ -0,0 +1,110 @@
from __future__ import annotations

import abc
from typing import Any, Iterator, Mapping, Optional, Sequence, Union

from langchain.load.serializable import Serializable
from langchain.schema import BaseMessage, Document


class InternalMessage(Serializable):
    @property
    def lc_serializable(self) -> bool:
        """Indicate whether the class is serializable."""
        return True


class FunctionCallRequest(InternalMessage):
    """A request for a function invocation.

    This message can be used to request a function invocation
    using the function name and the arguments to pass to the function.
    """

    name: str
    """The name of the function to invoke."""
    named_arguments: Optional[Mapping[str, Any]] = None
    """The named arguments to pass to the function."""

    class Config:
        extra = "forbid"

    def __str__(self) -> str:
        """Return a string representation of the object."""
        return f"FunctionCall(name={self.name}, named_arguments={self.named_arguments})"


class FunctionCallResponse(InternalMessage):  # Rename as FunctionCallResult
    """A result of a function invocation."""

    name: str
    result: Any
    error: Optional[str] = None

    def __str__(self) -> str:
        """Return a string representation of the object."""
        return (
            f"FunctionResult(name={self.name}, result={self.result}, "
            f"error={self.error})"
        )


class RetrievalRequest(InternalMessage):
    """A request for a retrieval."""

    query: str
    """The query to use for the retrieval."""

    def __str__(self) -> str:
        """Return a string representation of the object."""
        return f"RetrievalRequest(query={self.query})"


class RetrievalResponse(InternalMessage):
    """A result of a retrieval."""

    results: Sequence[Document]

    def __str__(self) -> str:
        """Return a string representation of the object."""
        return f"RetrievalResults(results={self.results})"


class AdHocMessage(InternalMessage):
    """A message that is used to prime the language model."""

    type: str
    data: Any  # Make sure this is serializable

    def __str__(self) -> str:
        """Return a string representation of the object."""
        return f"AdHocMessage(type={self.type}, data={self.data})"


class AgentFinish(InternalMessage):
    """A message that indicates that the agent is finished."""

    result: Any

    def __str__(self) -> str:
        """Return a string representation of the object."""
        return f"AgentFinish(result={self.result})"


MessageLike = Union[BaseMessage, InternalMessage]


class Agent(abc.ABC):
    @abc.abstractmethod
    def run(
        self,
        messages: Sequence[MessageLike],
        *,
        config: Optional[dict] = None,
        max_iterations: int = 100,
    ) -> Iterator[MessageLike]:
        """Run the agent."""
        raise NotImplementedError()
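
To make the Agent contract concrete, a minimal hypothetical implementation (a sketch, not part of the diff):

# Sketch only: a trivial agent that finishes on the first step.
class FinishImmediatelyAgent(Agent):
    """Hypothetical agent that emits a single AgentFinish and stops."""

    def run(
        self,
        messages: Sequence[MessageLike],
        *,
        config: Optional[dict] = None,
        max_iterations: int = 100,
    ) -> Iterator[MessageLike]:
        yield AgentFinish(result=messages[-1] if messages else None)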
@@ -79,6 +79,15 @@ class BaseMessage(Serializable):
        """Whether this class is LangChain serializable."""
        return True

    def __repr__(self) -> str:
        if self.additional_kwargs:
            return f"{self.type}: {self.additional_kwargs}"
        else:
            return f"{self.type}: {self.content}"

    def __str__(self) -> str:
        return self.__repr__()

    def __add__(self, other: Any) -> ChatPromptTemplate:
        from langchain.prompts.chat import ChatPromptTemplate