mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-22 07:05:36 +00:00
Callbacks Refactor [base] (#3256)
Co-authored-by: Nuno Campos <nuno@boringbits.io> Co-authored-by: Davis Chase <130488702+dev2049@users.noreply.github.com> Co-authored-by: Zander Chase <130414180+vowelparrot@users.noreply.github.com> Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
@@ -3,10 +3,14 @@
|
||||
import ast
|
||||
import sys
|
||||
from io import StringIO
|
||||
from typing import Dict, Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import Field, root_validator
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForToolRun,
|
||||
CallbackManagerForToolRun,
|
||||
)
|
||||
from langchain.tools.base import BaseTool
|
||||
from langchain.utilities import PythonREPL
|
||||
|
||||
@@ -28,13 +32,23 @@ class PythonREPLTool(BaseTool):
|
||||
python_repl: PythonREPL = Field(default_factory=_get_default_python_repl)
|
||||
sanitize_input: bool = True
|
||||
|
||||
def _run(self, query: str) -> str:
|
||||
def _run(
|
||||
self,
|
||||
query: str,
|
||||
run_manager: Optional[CallbackManagerForToolRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Use the tool."""
|
||||
if self.sanitize_input:
|
||||
query = query.strip().strip("```")
|
||||
return self.python_repl.run(query)
|
||||
|
||||
async def _arun(self, query: str) -> str:
|
||||
async def _arun(
|
||||
self,
|
||||
query: str,
|
||||
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Use the tool asynchronously."""
|
||||
raise NotImplementedError("PythonReplTool does not support async")
|
||||
|
||||
@@ -64,7 +78,11 @@ class PythonAstREPLTool(BaseTool):
|
||||
)
|
||||
return values
|
||||
|
||||
def _run(self, query: str) -> str:
|
||||
def _run(
|
||||
self,
|
||||
query: str,
|
||||
run_manager: Optional[CallbackManagerForToolRun] = None,
|
||||
) -> str:
|
||||
"""Use the tool."""
|
||||
try:
|
||||
if self.sanitize_input:
|
||||
@@ -91,6 +109,10 @@ class PythonAstREPLTool(BaseTool):
|
||||
except Exception as e:
|
||||
return "{}: {}".format(type(e).__name__, str(e))
|
||||
|
||||
async def _arun(self, query: str) -> str:
|
||||
async def _arun(
|
||||
self,
|
||||
query: str,
|
||||
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
||||
) -> str:
|
||||
"""Use the tool asynchronously."""
|
||||
raise NotImplementedError("PythonReplTool does not support async")
|
||||
|
||||
Reference in New Issue
Block a user