mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-18 10:43:36 +00:00
Improved query, print & exception handling in REPL Tool (#4997)
Update to pull request https://github.com/hwchase17/langchain/pull/3215 Summary: 1) Improved the sanitization of query (using regex), by removing python command (since gpt-3.5-turbo sometimes assumes python console as a terminal, and runs python command first which causes error). Also sometimes 1 line python codes contain single backticks. 2) Added 7 new test cases. For more details, view the previous pull request. --------- Co-authored-by: Deepak S V <svdeepak99@users.noreply.github.com>
This commit is contained in:
parent
785502edb3
commit
49ca02711e
@ -1,7 +1,9 @@
|
|||||||
"""A tool for running python code in a REPL."""
|
"""A tool for running python code in a REPL."""
|
||||||
|
|
||||||
import ast
|
import ast
|
||||||
|
import re
|
||||||
import sys
|
import sys
|
||||||
|
from contextlib import redirect_stdout
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
@ -19,14 +21,13 @@ def _get_default_python_repl() -> PythonREPL:
|
|||||||
return PythonREPL(_globals=globals(), _locals=None)
|
return PythonREPL(_globals=globals(), _locals=None)
|
||||||
|
|
||||||
|
|
||||||
_MD_PY_BLOCK = "```python"
|
|
||||||
|
|
||||||
|
|
||||||
def sanitize_input(query: str) -> str:
|
def sanitize_input(query: str) -> str:
|
||||||
query = query.strip()
|
# Remove whitespace, backtick & python (if llm mistakes python console as terminal)
|
||||||
if query[: len(_MD_PY_BLOCK)] == _MD_PY_BLOCK:
|
|
||||||
query = query[len(_MD_PY_BLOCK) :].strip()
|
# Removes `, whitespace & python from start
|
||||||
query = query.strip("`").strip()
|
query = re.sub(r"^(\s|`)*(?i:python)?\s*", "", query)
|
||||||
|
# Removes whitespace & ` from end
|
||||||
|
query = re.sub(r"(\s|`)*$", "", query)
|
||||||
return query
|
return query
|
||||||
|
|
||||||
|
|
||||||
@ -101,19 +102,18 @@ class PythonAstREPLTool(BaseTool):
|
|||||||
exec(ast.unparse(module), self.globals, self.locals) # type: ignore
|
exec(ast.unparse(module), self.globals, self.locals) # type: ignore
|
||||||
module_end = ast.Module(tree.body[-1:], type_ignores=[])
|
module_end = ast.Module(tree.body[-1:], type_ignores=[])
|
||||||
module_end_str = ast.unparse(module_end) # type: ignore
|
module_end_str = ast.unparse(module_end) # type: ignore
|
||||||
|
io_buffer = StringIO()
|
||||||
try:
|
try:
|
||||||
return eval(module_end_str, self.globals, self.locals)
|
with redirect_stdout(io_buffer):
|
||||||
|
ret = eval(module_end_str, self.globals, self.locals)
|
||||||
|
if ret is None:
|
||||||
|
return io_buffer.getvalue()
|
||||||
|
else:
|
||||||
|
return ret
|
||||||
except Exception:
|
except Exception:
|
||||||
old_stdout = sys.stdout
|
with redirect_stdout(io_buffer):
|
||||||
sys.stdout = mystdout = StringIO()
|
|
||||||
try:
|
|
||||||
exec(module_end_str, self.globals, self.locals)
|
exec(module_end_str, self.globals, self.locals)
|
||||||
sys.stdout = old_stdout
|
return io_buffer.getvalue()
|
||||||
output = mystdout.getvalue()
|
|
||||||
except Exception as e:
|
|
||||||
sys.stdout = old_stdout
|
|
||||||
output = repr(e)
|
|
||||||
return output
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return "{}: {}".format(type(e).__name__, str(e))
|
return "{}: {}".format(type(e).__name__, str(e))
|
||||||
|
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
"""Test Python REPL Tools."""
|
"""Test Python REPL Tools."""
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from langchain.tools.python.tool import (
|
from langchain.tools.python.tool import (
|
||||||
@ -17,6 +18,18 @@ def test_python_repl_tool_single_input() -> None:
|
|||||||
assert int(tool.run("print(1 + 1)").strip()) == 2
|
assert int(tool.run("print(1 + 1)").strip()) == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_python_repl_print() -> None:
|
||||||
|
program = """
|
||||||
|
import numpy as np
|
||||||
|
v1 = np.array([1, 2, 3])
|
||||||
|
v2 = np.array([4, 5, 6])
|
||||||
|
dot_product = np.dot(v1, v2)
|
||||||
|
print("The dot product is {:d}.".format(dot_product))
|
||||||
|
"""
|
||||||
|
tool = PythonREPLTool()
|
||||||
|
assert tool.run(program) == "The dot product is 32.\n"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
|
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
|
||||||
)
|
)
|
||||||
@ -27,6 +40,103 @@ def test_python_ast_repl_tool_single_input() -> None:
|
|||||||
assert tool.run("1 + 1") == 2
|
assert tool.run("1 + 1") == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
|
||||||
|
)
|
||||||
|
def test_python_ast_repl_return() -> None:
|
||||||
|
program = """
|
||||||
|
```
|
||||||
|
import numpy as np
|
||||||
|
v1 = np.array([1, 2, 3])
|
||||||
|
v2 = np.array([4, 5, 6])
|
||||||
|
dot_product = np.dot(v1, v2)
|
||||||
|
int(dot_product)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
tool = PythonAstREPLTool()
|
||||||
|
assert tool.run(program) == 32
|
||||||
|
|
||||||
|
program = """
|
||||||
|
```python
|
||||||
|
import numpy as np
|
||||||
|
v1 = np.array([1, 2, 3])
|
||||||
|
v2 = np.array([4, 5, 6])
|
||||||
|
dot_product = np.dot(v1, v2)
|
||||||
|
int(dot_product)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
assert tool.run(program) == 32
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
|
||||||
|
)
|
||||||
|
def test_python_ast_repl_print() -> None:
|
||||||
|
program = """python
|
||||||
|
string = "racecar"
|
||||||
|
if string == string[::-1]:
|
||||||
|
print(string, "is a palindrome")
|
||||||
|
else:
|
||||||
|
print(string, "is not a palindrome")"""
|
||||||
|
tool = PythonAstREPLTool()
|
||||||
|
assert tool.run(program) == "racecar is a palindrome\n"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
|
||||||
|
)
|
||||||
|
def test_repl_print_python_backticks() -> None:
|
||||||
|
program = "`print('`python` is a great language.')`"
|
||||||
|
tool = PythonAstREPLTool()
|
||||||
|
assert tool.run(program) == "`python` is a great language.\n"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
|
||||||
|
)
|
||||||
|
def test_python_ast_repl_raise_exception() -> None:
|
||||||
|
data = {"Name": ["John", "Alice"], "Age": [30, 25]}
|
||||||
|
program = """
|
||||||
|
import pandas as pd
|
||||||
|
df = pd.DataFrame(data)
|
||||||
|
df['Gender']
|
||||||
|
"""
|
||||||
|
tool = PythonAstREPLTool(locals={"data": data})
|
||||||
|
expected_outputs = (
|
||||||
|
"KeyError: 'Gender'",
|
||||||
|
"ModuleNotFoundError: No module named 'pandas'",
|
||||||
|
)
|
||||||
|
assert tool.run(program) in expected_outputs
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
|
||||||
|
)
|
||||||
|
def test_python_ast_repl_one_line_print() -> None:
|
||||||
|
program = 'print("The square of {} is {:.2f}".format(3, 3**2))'
|
||||||
|
tool = PythonAstREPLTool()
|
||||||
|
assert tool.run(program) == "The square of 3 is 9.00\n"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
|
||||||
|
)
|
||||||
|
def test_python_ast_repl_one_line_return() -> None:
|
||||||
|
arr = np.array([1, 2, 3, 4, 5])
|
||||||
|
tool = PythonAstREPLTool(locals={"arr": arr})
|
||||||
|
program = "`(arr**2).sum() # Returns sum of squares`"
|
||||||
|
assert tool.run(program) == 55
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
|
||||||
|
)
|
||||||
|
def test_python_ast_repl_one_line_exception() -> None:
|
||||||
|
program = "[1, 2, 3][4]"
|
||||||
|
tool = PythonAstREPLTool()
|
||||||
|
assert tool.run(program) == "IndexError: list index out of range"
|
||||||
|
|
||||||
|
|
||||||
def test_sanitize_input() -> None:
|
def test_sanitize_input() -> None:
|
||||||
query = """
|
query = """
|
||||||
```
|
```
|
||||||
|
Loading…
Reference in New Issue
Block a user