mirror of
https://github.com/hwchase17/langchain.git
synced 2026-04-23 20:23:59 +00:00
foo
This commit is contained in:
@@ -1,161 +1,21 @@
|
||||
"""Base implementation for tools or skills."""
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Generic, Optional, TypeVar
|
||||
from abc import ABC
|
||||
from typing import Type
|
||||
|
||||
from pydantic import (
|
||||
Extra,
|
||||
Field,
|
||||
validator,
|
||||
BaseModel,
|
||||
)
|
||||
|
||||
from pydantic.generics import GenericModel
|
||||
from langchain.callbacks import get_callback_manager
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.utilities.async_utils import async_or_sync_call
|
||||
from langchain.tools.structured import BaseStructuredTool
|
||||
|
||||
|
||||
IN_T = TypeVar("IN_T")
|
||||
OUT_T = TypeVar("OUT_T")
|
||||
|
||||
|
||||
class ToolMixin(GenericModel, Generic[IN_T, OUT_T]):
|
||||
class BaseTool(ABC, BaseStructuredTool[str, str, str]):
|
||||
"""Interface LangChain tools must implement."""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
return_direct: bool = False
|
||||
verbose: bool = False
|
||||
callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
|
||||
args_schema: Type[str] = str # :meta private:
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@validator("callback_manager", pre=True, always=True)
|
||||
def set_callback_manager(
|
||||
cls, callback_manager: Optional[BaseCallbackManager]
|
||||
) -> BaseCallbackManager:
|
||||
"""If callback manager is None, set it.
|
||||
|
||||
This allows users to pass in None as callback manager, which is a nice UX.
|
||||
"""
|
||||
return callback_manager or get_callback_manager()
|
||||
|
||||
def _get_verbosity(
|
||||
self,
|
||||
verbose: Optional[bool] = None,
|
||||
) -> bool:
|
||||
if not self.verbose and verbose is not None:
|
||||
verbose_ = verbose
|
||||
else:
|
||||
verbose_ = self.verbose
|
||||
return verbose_
|
||||
|
||||
@abstractmethod
|
||||
def run(
|
||||
self,
|
||||
tool_input: IN_T,
|
||||
verbose: Optional[bool] = None,
|
||||
start_color: Optional[str] = "green",
|
||||
color: Optional[str] = "green",
|
||||
**kwargs: Any,
|
||||
) -> OUT_T:
|
||||
"""Use the tool."""
|
||||
|
||||
@abstractmethod
|
||||
async def arun(
|
||||
self,
|
||||
tool_input: IN_T,
|
||||
verbose: Optional[bool] = None,
|
||||
start_color: Optional[str] = "green",
|
||||
color: Optional[str] = "green",
|
||||
**kwargs: Any,
|
||||
) -> OUT_T:
|
||||
"""Use the tool asynchronously."""
|
||||
|
||||
def __call__(self, tool_input: IN_T) -> OUT_T:
|
||||
"""Make tool callable."""
|
||||
return self.run(tool_input)
|
||||
|
||||
|
||||
class BaseTool(ABC, ToolMixin[str, str]):
|
||||
"""Interface LangChain tools must implement."""
|
||||
|
||||
@abstractmethod
|
||||
def _run(self, tool_input: str) -> str:
|
||||
"""Use the tool."""
|
||||
|
||||
@abstractmethod
|
||||
async def _arun(self, tool_input: str) -> str:
|
||||
"""Use the tool asynchronously."""
|
||||
|
||||
def run(
|
||||
self,
|
||||
tool_input: str,
|
||||
verbose: Optional[bool] = None,
|
||||
start_color: Optional[str] = "green",
|
||||
color: Optional[str] = "green",
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Run the tool."""
|
||||
verbose_ = self._get_verbosity(verbose)
|
||||
self.callback_manager.on_tool_start(
|
||||
{"name": self.name, "description": self.description},
|
||||
tool_input,
|
||||
verbose=verbose_,
|
||||
color=start_color,
|
||||
**kwargs,
|
||||
)
|
||||
try:
|
||||
observation = self._run(tool_input)
|
||||
except (Exception, KeyboardInterrupt) as e:
|
||||
self.callback_manager.on_tool_error(e, verbose=verbose_)
|
||||
raise e
|
||||
self.callback_manager.on_tool_end(
|
||||
observation, verbose=verbose_, color=color, name=self.name, **kwargs
|
||||
)
|
||||
return observation
|
||||
|
||||
async def arun(
|
||||
self,
|
||||
tool_input: str,
|
||||
verbose: Optional[bool] = None,
|
||||
start_color: Optional[str] = "green",
|
||||
color: Optional[str] = "green",
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Run the tool asynchronously."""
|
||||
verbose_ = self._get_verbosity(verbose)
|
||||
await async_or_sync_call(
|
||||
self.callback_manager.on_tool_start,
|
||||
{"name": self.name, "description": self.description},
|
||||
tool_input,
|
||||
verbose=verbose_,
|
||||
color=start_color,
|
||||
is_async=self.callback_manager.is_async,
|
||||
**kwargs,
|
||||
)
|
||||
try:
|
||||
# We then call the tool on the tool input to get an observation
|
||||
observation = await self._arun(tool_input)
|
||||
except (Exception, KeyboardInterrupt) as e:
|
||||
await async_or_sync_call(
|
||||
self.callback_manager.on_tool_error,
|
||||
e,
|
||||
verbose=verbose_,
|
||||
is_async=self.callback_manager.is_async,
|
||||
)
|
||||
raise e
|
||||
await async_or_sync_call(
|
||||
self.callback_manager.on_tool_end,
|
||||
observation,
|
||||
verbose=verbose_,
|
||||
color=color,
|
||||
is_async=self.callback_manager.is_async,
|
||||
**kwargs,
|
||||
)
|
||||
return observation
|
||||
def _parse_input(self, tool_input: str) -> str:
|
||||
"""Load the tool's input into a pydantic model."""
|
||||
return tool_input
|
||||
|
||||
@@ -3,16 +3,13 @@ from __future__ import annotations
|
||||
from abc import abstractmethod
|
||||
from functools import partial
|
||||
from inspect import signature
|
||||
import inspect
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
Generic,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
@@ -21,14 +18,18 @@ from typing import (
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
Extra,
|
||||
Field,
|
||||
create_model,
|
||||
validate_arguments,
|
||||
validator,
|
||||
)
|
||||
from pydantic.generics import GenericModel
|
||||
from pydantic.main import ModelMetaclass
|
||||
from langchain.callbacks import get_callback_manager
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
|
||||
from langchain.tools.base import BaseTool, ToolMixin
|
||||
from langchain.tools.base import BaseTool
|
||||
from langchain.utilities.async_utils import async_or_sync_call
|
||||
|
||||
|
||||
@@ -36,38 +37,137 @@ class SchemaAnnotationError(TypeError):
|
||||
"""Raised when 'args_schema' is missing or has an incorrect type annotation."""
|
||||
|
||||
|
||||
class ToolMetaclass(ModelMetaclass):
|
||||
"""Metaclass for BaseTool to ensure the provided args_schema
|
||||
INPUT_T = TypeVar("INPUT_T")
|
||||
SCHEMA_T = TypeVar("SCHEMA_T", bound=Union[str, BaseModel])
|
||||
OUTPUT_T = TypeVar("OUTPUT_T")
|
||||
|
||||
doesn't silently ignored."""
|
||||
|
||||
def __new__(
|
||||
cls: Type[ToolMetaclass], name: str, bases: Tuple[Type, ...], dct: dict
|
||||
) -> ToolMetaclass:
|
||||
"""Create the definition of the new tool class."""
|
||||
schema_type: Optional[Type[BaseModel]] = dct.get("args_schema")
|
||||
if schema_type is not None:
|
||||
schema_annotations = dct.get("__annotations__", {})
|
||||
args_schema_type = schema_annotations.get("args_schema", None)
|
||||
if args_schema_type is None or args_schema_type == BaseModel:
|
||||
# Throw errors for common mis-annotations.
|
||||
# TODO: Use get_args / get_origin and fully
|
||||
# specify valid annotations.
|
||||
typehint_mandate = """
|
||||
class ChildTool(BaseTool):
|
||||
...
|
||||
args_schema: Type[BaseModel] = SchemaClass
|
||||
..."""
|
||||
raise SchemaAnnotationError(
|
||||
f"StructuredTool definition for {name} must include valid type annotations"
|
||||
f" for argument 'args_schema' to behave as expected.\n"
|
||||
f"Expected annotation of 'Type[BaseModel]'"
|
||||
f" but got '{args_schema_type}'.\n"
|
||||
f"Expected class looks like:\n"
|
||||
f"{typehint_mandate}"
|
||||
)
|
||||
# Pass through to Pydantic's metaclass
|
||||
return super().__new__(cls, name, bases, dct)
|
||||
class BaseStructuredTool(
|
||||
GenericModel,
|
||||
Generic[INPUT_T, SCHEMA_T, OUTPUT_T],
|
||||
BaseModel,
|
||||
):
|
||||
"""Parent class for all structured tools."""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
return_direct: bool = False
|
||||
verbose: bool = False
|
||||
callback_manager: BaseCallbackManager = Field(default_factory=get_callback_manager)
|
||||
args_schema: Type[SCHEMA_T] # :meta private:
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@property
|
||||
def args(self) -> dict:
|
||||
if isinstance(self.args_schema, BaseModel):
|
||||
return self.args_schema.schema()["properties"]
|
||||
else:
|
||||
return {"tool_input": "str"}
|
||||
|
||||
def _parse_input(self, tool_input: Union[INPUT_T, Any]) -> SCHEMA_T:
|
||||
"""Load the tool's input into a pydantic model."""
|
||||
if not issubclass(self.args_schema, BaseModel):
|
||||
raise ValueError(
|
||||
f"Tool with args_schema of type {self.args_schema} must overwrite _parse_input."
|
||||
)
|
||||
# Ignore type because mypy doesn't connect the subclass to the generic SCHEMA_T
|
||||
return self.args_schema.parse_obj(tool_input) # type: ignore
|
||||
|
||||
def _get_verbosity(
|
||||
self,
|
||||
verbose: Optional[bool] = None,
|
||||
) -> bool:
|
||||
if not self.verbose and verbose is not None:
|
||||
verbose_ = verbose
|
||||
else:
|
||||
verbose_ = self.verbose
|
||||
return verbose_
|
||||
|
||||
@abstractmethod
|
||||
def _run(self, input_: SCHEMA_T) -> OUTPUT_T:
|
||||
"""Use the tool."""
|
||||
|
||||
@abstractmethod
|
||||
async def _arun(self, input_: SCHEMA_T) -> OUTPUT_T:
|
||||
"""Use the tool asynchronously."""
|
||||
|
||||
def run(
|
||||
self,
|
||||
tool_input: INPUT_T,
|
||||
verbose: Optional[bool] = None,
|
||||
start_color: Optional[str] = "green",
|
||||
color: Optional[str] = "green",
|
||||
**kwargs: Any,
|
||||
) -> OUTPUT_T:
|
||||
"""Run the tool."""
|
||||
parsed_input = self._parse_input(tool_input)
|
||||
verbose_ = self._get_verbosity(verbose)
|
||||
self.callback_manager.on_tool_start(
|
||||
{"name": self.name, "description": self.description},
|
||||
str(tool_input),
|
||||
verbose=verbose_,
|
||||
color=start_color,
|
||||
**kwargs,
|
||||
)
|
||||
try:
|
||||
observation = self._run(parsed_input)
|
||||
except (Exception, KeyboardInterrupt) as e:
|
||||
self.callback_manager.on_tool_error(e, verbose=verbose_)
|
||||
raise e
|
||||
self.callback_manager.on_tool_end(
|
||||
str(observation), verbose=verbose_, color=color, name=self.name, **kwargs
|
||||
)
|
||||
return observation
|
||||
|
||||
async def arun(
|
||||
self,
|
||||
tool_input: INPUT_T,
|
||||
verbose: Optional[bool] = None,
|
||||
start_color: Optional[str] = "green",
|
||||
color: Optional[str] = "green",
|
||||
**kwargs: Any,
|
||||
) -> OUTPUT_T:
|
||||
"""Run the tool asynchronously."""
|
||||
parsed_input = self._parse_input(tool_input)
|
||||
verbose_ = self._get_verbosity(verbose)
|
||||
await async_or_sync_call(
|
||||
self.callback_manager.on_tool_start,
|
||||
{"name": self.name, "description": self.description},
|
||||
str(parsed_input),
|
||||
verbose=verbose_,
|
||||
color=start_color,
|
||||
is_async=self.callback_manager.is_async,
|
||||
**kwargs,
|
||||
)
|
||||
try:
|
||||
# We then call the tool on the tool input to get an observation
|
||||
observation = await self._arun(parsed_input)
|
||||
except (Exception, KeyboardInterrupt) as e:
|
||||
await async_or_sync_call(
|
||||
self.callback_manager.on_tool_error,
|
||||
e,
|
||||
verbose=verbose_,
|
||||
is_async=self.callback_manager.is_async,
|
||||
)
|
||||
raise e
|
||||
await async_or_sync_call(
|
||||
self.callback_manager.on_tool_end,
|
||||
str(observation),
|
||||
verbose=verbose_,
|
||||
color=color,
|
||||
is_async=self.callback_manager.is_async,
|
||||
**kwargs,
|
||||
)
|
||||
return observation
|
||||
|
||||
def __call__(self, tool_input: INPUT_T) -> OUTPUT_T:
|
||||
"""Make tool callable."""
|
||||
return self.run(tool_input)
|
||||
|
||||
|
||||
def _create_subset_model(
|
||||
@@ -102,122 +202,14 @@ def create_schema_from_function(model_name: str, func: Callable) -> Type[BaseMod
|
||||
)
|
||||
|
||||
|
||||
class BaseStructuredTool(ToolMixin[Dict, Dict], metaclass=ToolMetaclass):
|
||||
"""Parent class for all structured tools."""
|
||||
|
||||
args_schema: Type[BaseModel] # :meta private:
|
||||
|
||||
@property
|
||||
def args(self) -> dict:
|
||||
if self.args_schema is not None:
|
||||
return self.args_schema.schema()["properties"]
|
||||
else:
|
||||
inferred_model = validate_arguments(self._run).model # type: ignore
|
||||
return get_filtered_args(inferred_model, self._run)
|
||||
|
||||
def _load_parsed_input(self, tool_input: Union[Dict, Any]) -> BaseModel:
|
||||
"""Load the tool's input into a pydantic model."""
|
||||
if not isinstance(tool_input, dict):
|
||||
# Despite being typed as a Dict, there are cases when the LLM
|
||||
# will not actually output dict args (e.g., for single arg inputs).
|
||||
single_field = next(iter(self.args_schema.__fields__))
|
||||
tool_input = {single_field: tool_input}
|
||||
return self.args_schema.parse_obj(tool_input)
|
||||
|
||||
@abstractmethod
|
||||
def _run(self, input_: BaseModel) -> Any:
|
||||
"""Use the tool."""
|
||||
|
||||
@abstractmethod
|
||||
async def _arun(self, input_: BaseModel) -> Any:
|
||||
"""Use the tool asynchronously."""
|
||||
|
||||
def run(
|
||||
self,
|
||||
tool_input: Dict,
|
||||
verbose: Optional[bool] = None,
|
||||
start_color: Optional[str] = "green",
|
||||
color: Optional[str] = "green",
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run the tool."""
|
||||
parsed_input = self.args_schema.parse_obj(tool_input)
|
||||
verbose_ = self._get_verbosity(verbose)
|
||||
self.callback_manager.on_tool_start(
|
||||
{"name": self.name, "description": self.description},
|
||||
str(tool_input),
|
||||
verbose=verbose_,
|
||||
color=start_color,
|
||||
**kwargs,
|
||||
)
|
||||
try:
|
||||
observation = self._run(parsed_input)
|
||||
except (Exception, KeyboardInterrupt) as e:
|
||||
self.callback_manager.on_tool_error(e, verbose=verbose_)
|
||||
raise e
|
||||
self.callback_manager.on_tool_end(
|
||||
str(observation), verbose=verbose_, color=color, name=self.name, **kwargs
|
||||
)
|
||||
if isinstance(observation, BaseModel):
|
||||
observation = observation.dict()
|
||||
return str(observation)
|
||||
|
||||
async def arun(
|
||||
self,
|
||||
tool_input: Union[Dict, Any],
|
||||
verbose: Optional[bool] = None,
|
||||
start_color: Optional[str] = "green",
|
||||
color: Optional[str] = "green",
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run the tool asynchronously."""
|
||||
parsed_input = self.args_schema.parse_obj(tool_input)
|
||||
verbose_ = self._get_verbosity(verbose)
|
||||
await async_or_sync_call(
|
||||
self.callback_manager.on_tool_start,
|
||||
{"name": self.name, "description": self.description},
|
||||
str(parsed_input.dict()),
|
||||
verbose=verbose_,
|
||||
color=start_color,
|
||||
is_async=self.callback_manager.is_async,
|
||||
**kwargs,
|
||||
)
|
||||
try:
|
||||
# We then call the tool on the tool input to get an observation
|
||||
observation = await self._run(parsed_input)
|
||||
except (Exception, KeyboardInterrupt) as e:
|
||||
await async_or_sync_call(
|
||||
self.callback_manager.on_tool_error,
|
||||
e,
|
||||
verbose=verbose_,
|
||||
is_async=self.callback_manager.is_async,
|
||||
)
|
||||
raise e
|
||||
if isinstance(observation, BaseModel):
|
||||
observation = observation.dict()
|
||||
await async_or_sync_call(
|
||||
self.callback_manager.on_tool_end,
|
||||
str(observation),
|
||||
verbose=verbose_,
|
||||
color=color,
|
||||
is_async=self.callback_manager.is_async,
|
||||
**kwargs,
|
||||
)
|
||||
return observation
|
||||
|
||||
def __call__(self, tool_input: Dict) -> Any:
|
||||
"""Make tool callable."""
|
||||
return self.run(tool_input)
|
||||
|
||||
|
||||
class StructuredTool(BaseStructuredTool):
|
||||
class StructuredTool(BaseStructuredTool[Dict, BaseModel, Any]):
|
||||
"""StructuredTool that takes in function or coroutine directly."""
|
||||
|
||||
description: str = ""
|
||||
func: Callable[..., Any]
|
||||
"""The function to run when the tool is called."""
|
||||
coroutine: Optional[Callable[..., Awaitable[Any]]] = None
|
||||
"""The asynchronous version of the function."""
|
||||
args_schema: Type[BaseModel] # :meta private:
|
||||
|
||||
@validator("func", pre=True, always=True)
|
||||
def validate_func_not_partial(cls, func: Callable) -> Callable:
|
||||
@@ -242,7 +234,7 @@ class StructuredTool(BaseStructuredTool):
|
||||
"""Use the tool asynchronously."""
|
||||
if self.coroutine:
|
||||
return await self.coroutine(**tool_input.dict())
|
||||
raise NotImplementedError("StructuredTool does not support async")
|
||||
raise NotImplementedError(f"StructuredTool {self.name} does not support async")
|
||||
|
||||
@classmethod
|
||||
def from_function(
|
||||
@@ -354,6 +346,3 @@ def structured_tool(
|
||||
return _partial
|
||||
else:
|
||||
raise ValueError("Too many arguments for tool decorator")
|
||||
|
||||
|
||||
TOOL_T = TypeVar("TOOL_T", bound=ToolMixin, default=BaseTool)
|
||||
|
||||
Reference in New Issue
Block a user