mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-31 18:38:48 +00:00
Cadlabs/python tool sanitization (#4754)
Co-authored-by: BenSchZA <BenSchZA@users.noreply.github.com>
This commit is contained in:
@@ -19,6 +19,17 @@ def _get_default_python_repl() -> PythonREPL:
|
||||
return PythonREPL(_globals=globals(), _locals=None)
|
||||
|
||||
|
||||
_MD_PY_BLOCK = "```python"
|
||||
|
||||
|
||||
def sanitize_input(query: str) -> str:
|
||||
query = query.strip()
|
||||
if query[: len(_MD_PY_BLOCK)] == _MD_PY_BLOCK:
|
||||
query = query[len(_MD_PY_BLOCK) :].strip()
|
||||
query = query.strip("`").strip()
|
||||
return query
|
||||
|
||||
|
||||
class PythonREPLTool(BaseTool):
|
||||
"""A tool for running python code in a REPL."""
|
||||
|
||||
@@ -39,7 +50,7 @@ class PythonREPLTool(BaseTool):
|
||||
) -> Any:
|
||||
"""Use the tool."""
|
||||
if self.sanitize_input:
|
||||
query = query.strip().strip("```")
|
||||
query = sanitize_input(query)
|
||||
return self.python_repl.run(query)
|
||||
|
||||
async def _arun(
|
||||
@@ -84,8 +95,7 @@ class PythonAstREPLTool(BaseTool):
|
||||
"""Use the tool."""
|
||||
try:
|
||||
if self.sanitize_input:
|
||||
# Remove the triple backticks from the query.
|
||||
query = query.strip().strip("```")
|
||||
query = sanitize_input(query)
|
||||
tree = ast.parse(query)
|
||||
module = ast.Module(tree.body[:-1], type_ignores=[])
|
||||
exec(ast.unparse(module), self.globals, self.locals) # type: ignore
|
||||
|
@@ -3,7 +3,11 @@ import sys
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.tools.python.tool import PythonAstREPLTool, PythonREPLTool
|
||||
from langchain.tools.python.tool import (
|
||||
PythonAstREPLTool,
|
||||
PythonREPLTool,
|
||||
sanitize_input,
|
||||
)
|
||||
|
||||
|
||||
def test_python_repl_tool_single_input() -> None:
|
||||
@@ -21,3 +25,30 @@ def test_python_ast_repl_tool_single_input() -> None:
|
||||
tool = PythonAstREPLTool()
|
||||
assert tool.is_single_input
|
||||
assert tool.run("1 + 1") == 2
|
||||
|
||||
|
||||
def test_sanitize_input() -> None:
|
||||
query = """
|
||||
```
|
||||
p = 5
|
||||
```
|
||||
"""
|
||||
expected = "p = 5"
|
||||
actual = sanitize_input(query)
|
||||
assert expected == actual
|
||||
|
||||
query = """
|
||||
```python
|
||||
p = 5
|
||||
```
|
||||
"""
|
||||
expected = "p = 5"
|
||||
actual = sanitize_input(query)
|
||||
assert expected == actual
|
||||
|
||||
query = """
|
||||
p = 5
|
||||
"""
|
||||
expected = "p = 5"
|
||||
actual = sanitize_input(query)
|
||||
assert expected == actual
|
||||
|
Reference in New Issue
Block a user