Compare commits

...

1 Commits

Author SHA1 Message Date
Bagatur
ec73265930 wip 2024-01-12 10:25:58 -08:00
3 changed files with 65 additions and 5 deletions

View File

@@ -10,15 +10,12 @@ from langchain_core.messages import (
FunctionMessage,
HumanMessage,
)
from langchain_core.tools import ToolInvocation
class AgentAction(Serializable):
class AgentAction(ToolInvocation):
"""A full description of an action for an ActionAgent to execute."""
tool: str
"""The name of the Tool to execute."""
tool_input: Union[str, dict]
"""The input to pass in to the Tool."""
log: str
"""Additional information to log about the action.
This log can be used in a few ways. First, it can be used to audit

View File

@@ -15,6 +15,7 @@ from langchain_core.callbacks import (
CallbackManagerForToolRun,
Callbacks,
)
from langchain_core.load import Serializable
from langchain_core.load.serializable import Serializable
from langchain_core.pydantic_v1 import (
BaseModel,
@@ -37,6 +38,14 @@ class SchemaAnnotationError(TypeError):
"""Raised when 'args_schema' is missing or has an incorrect type annotation."""
class ToolInvocation(Serializable):
"""The name and input for a tool to invoke."""
tool: str
"""The name of the Tool to execute."""
tool_input: Union[str, dict]
"""The input to pass in to the Tool."""
def _create_subset_model(
name: str, model: BaseModel, field_names: list
) -> Type[BaseModel]:
@@ -845,3 +854,4 @@ def tool(
return _partial
else:
raise ValueError("Too many arguments for tool decorator")

View File

@@ -0,0 +1,53 @@
from typing import Any, Sequence
from langchain_core.runnables.base import RunnableBindingBase, RunnableLambda
from langchain_core.tools import BaseTool, ToolInvocation
class InvalidToolNameException(Exception):
def __init__(
self,
requested: str,
available: Sequence[str],
) -> None:
self.requested = requested
self.available = available
def __str__(self) -> str:
available_str = ", ".join(t for t in self.available)
return (
f"{self.__class__.__name__}: Requested tool {self.requested} does not "
f"exist. Available tools are {available_str}."
)
class ToolExecutor(RunnableBindingBase):
def __init__(self, tools: Sequence[BaseTool], **kwargs: Any) -> None:
self._tools = tools
self._tool_map = {t.name: t for t in self._tools}
bound = RunnableLambda(self._execute, afunc=self._aexecute)
super().__init__(bound=bound, **kwargs)
def _execute(self, tool_invocation: ToolInvocation) -> Any:
if tool_invocation.tool not in self._tool_map:
exception = InvalidToolNameException(
requested=tool_invocation.tool,
available=list(self._tool_map.keys()),
)
return tool_invocation, exception
else:
tool = self._tool_map[tool_invocation.tool]
output = tool.invoke(tool_invocation.tool_input)
return tool_invocation, output
async def _aexecute(self, tool_invocation: ToolInvocation) -> Any:
if tool_invocation.tool not in self._tool_map:
exception = InvalidToolNameException(
requested=tool_invocation.tool,
available=list(self._tool_map.keys()),
)
return tool_invocation, exception
else:
tool = self._tool_map[tool_invocation.tool]
output = await tool.ainvoke(tool_invocation.tool_input)
return tool_invocation, output