mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-03 12:07:36 +00:00
Cadlabs/python tool sanitization (#4754)
Co-authored-by: BenSchZA <BenSchZA@users.noreply.github.com>
This commit is contained in:
@@ -19,6 +19,17 @@ def _get_default_python_repl() -> PythonREPL:
|
|||||||
return PythonREPL(_globals=globals(), _locals=None)
|
return PythonREPL(_globals=globals(), _locals=None)
|
||||||
|
|
||||||
|
|
||||||
|
_MD_PY_BLOCK = "```python"
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_input(query: str) -> str:
|
||||||
|
query = query.strip()
|
||||||
|
if query[: len(_MD_PY_BLOCK)] == _MD_PY_BLOCK:
|
||||||
|
query = query[len(_MD_PY_BLOCK) :].strip()
|
||||||
|
query = query.strip("`").strip()
|
||||||
|
return query
|
||||||
|
|
||||||
|
|
||||||
class PythonREPLTool(BaseTool):
|
class PythonREPLTool(BaseTool):
|
||||||
"""A tool for running python code in a REPL."""
|
"""A tool for running python code in a REPL."""
|
||||||
|
|
||||||
@@ -39,7 +50,7 @@ class PythonREPLTool(BaseTool):
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
"""Use the tool."""
|
"""Use the tool."""
|
||||||
if self.sanitize_input:
|
if self.sanitize_input:
|
||||||
query = query.strip().strip("```")
|
query = sanitize_input(query)
|
||||||
return self.python_repl.run(query)
|
return self.python_repl.run(query)
|
||||||
|
|
||||||
async def _arun(
|
async def _arun(
|
||||||
@@ -84,8 +95,7 @@ class PythonAstREPLTool(BaseTool):
|
|||||||
"""Use the tool."""
|
"""Use the tool."""
|
||||||
try:
|
try:
|
||||||
if self.sanitize_input:
|
if self.sanitize_input:
|
||||||
# Remove the triple backticks from the query.
|
query = sanitize_input(query)
|
||||||
query = query.strip().strip("```")
|
|
||||||
tree = ast.parse(query)
|
tree = ast.parse(query)
|
||||||
module = ast.Module(tree.body[:-1], type_ignores=[])
|
module = ast.Module(tree.body[:-1], type_ignores=[])
|
||||||
exec(ast.unparse(module), self.globals, self.locals) # type: ignore
|
exec(ast.unparse(module), self.globals, self.locals) # type: ignore
|
||||||
|
@@ -3,7 +3,11 @@ import sys
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from langchain.tools.python.tool import PythonAstREPLTool, PythonREPLTool
|
from langchain.tools.python.tool import (
|
||||||
|
PythonAstREPLTool,
|
||||||
|
PythonREPLTool,
|
||||||
|
sanitize_input,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_python_repl_tool_single_input() -> None:
|
def test_python_repl_tool_single_input() -> None:
|
||||||
@@ -21,3 +25,30 @@ def test_python_ast_repl_tool_single_input() -> None:
|
|||||||
tool = PythonAstREPLTool()
|
tool = PythonAstREPLTool()
|
||||||
assert tool.is_single_input
|
assert tool.is_single_input
|
||||||
assert tool.run("1 + 1") == 2
|
assert tool.run("1 + 1") == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_sanitize_input() -> None:
|
||||||
|
query = """
|
||||||
|
```
|
||||||
|
p = 5
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
expected = "p = 5"
|
||||||
|
actual = sanitize_input(query)
|
||||||
|
assert expected == actual
|
||||||
|
|
||||||
|
query = """
|
||||||
|
```python
|
||||||
|
p = 5
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
expected = "p = 5"
|
||||||
|
actual = sanitize_input(query)
|
||||||
|
assert expected == actual
|
||||||
|
|
||||||
|
query = """
|
||||||
|
p = 5
|
||||||
|
"""
|
||||||
|
expected = "p = 5"
|
||||||
|
actual = sanitize_input(query)
|
||||||
|
assert expected == actual
|
||||||
|
Reference in New Issue
Block a user