mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-06 21:20:33 +00:00
Fixed Ast Python Repl for Chatgpt multiline commands (#2406)
Resolves issue https://github.com/hwchase17/langchain/issues/2252 --------- Co-authored-by: Abhik Singla <abhiksingla@microsoft.com>
This commit is contained in:
parent
1271c00ff0
commit
955bd2e1db
@ -50,6 +50,7 @@ class PythonAstREPLTool(BaseTool):
|
|||||||
)
|
)
|
||||||
globals: Optional[Dict] = Field(default_factory=dict)
|
globals: Optional[Dict] = Field(default_factory=dict)
|
||||||
locals: Optional[Dict] = Field(default_factory=dict)
|
locals: Optional[Dict] = Field(default_factory=dict)
|
||||||
|
sanitize_input: bool = True
|
||||||
|
|
||||||
@root_validator(pre=True)
|
@root_validator(pre=True)
|
||||||
def validate_python_version(cls, values: Dict) -> Dict:
|
def validate_python_version(cls, values: Dict) -> Dict:
|
||||||
@ -65,6 +66,9 @@ class PythonAstREPLTool(BaseTool):
|
|||||||
def _run(self, query: str) -> str:
|
def _run(self, query: str) -> str:
|
||||||
"""Use the tool."""
|
"""Use the tool."""
|
||||||
try:
|
try:
|
||||||
|
if self.sanitize_input:
|
||||||
|
# Remove the triple backticks from the query.
|
||||||
|
query = query.strip().strip("```")
|
||||||
tree = ast.parse(query)
|
tree = ast.parse(query)
|
||||||
module = ast.Module(tree.body[:-1], type_ignores=[])
|
module = ast.Module(tree.body[:-1], type_ignores=[])
|
||||||
exec(ast.unparse(module), self.globals, self.locals) # type: ignore
|
exec(ast.unparse(module), self.globals, self.locals) # type: ignore
|
||||||
|
@ -1,7 +1,10 @@
|
|||||||
"""Test functionality of Python REPL."""
|
"""Test functionality of Python REPL."""
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from langchain.python import PythonREPL
|
from langchain.python import PythonREPL
|
||||||
from langchain.tools.python.tool import PythonREPLTool
|
from langchain.tools.python.tool import PythonAstREPLTool, PythonREPLTool
|
||||||
|
|
||||||
_SAMPLE_CODE = """
|
_SAMPLE_CODE = """
|
||||||
```
|
```
|
||||||
@ -11,6 +14,14 @@ multiply()
|
|||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
_AST_SAMPLE_CODE = """
|
||||||
|
```
|
||||||
|
def multiply():
|
||||||
|
return(5*6)
|
||||||
|
multiply()
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
def test_python_repl() -> None:
|
def test_python_repl() -> None:
|
||||||
"""Test functionality when globals/locals are not provided."""
|
"""Test functionality when globals/locals are not provided."""
|
||||||
@ -60,6 +71,15 @@ def test_functionality_multiline() -> None:
|
|||||||
assert output == "30\n"
|
assert output == "30\n"
|
||||||
|
|
||||||
|
|
||||||
|
def test_python_ast_repl_multiline() -> None:
|
||||||
|
"""Test correct functionality for ChatGPT multiline commands."""
|
||||||
|
if sys.version_info < (3, 9):
|
||||||
|
pytest.skip("Python 3.9+ is required for this test")
|
||||||
|
tool = PythonAstREPLTool()
|
||||||
|
output = tool.run(_AST_SAMPLE_CODE)
|
||||||
|
assert output == 30
|
||||||
|
|
||||||
|
|
||||||
def test_function() -> None:
|
def test_function() -> None:
|
||||||
"""Test correct functionality."""
|
"""Test correct functionality."""
|
||||||
chain = PythonREPL()
|
chain = PythonREPL()
|
||||||
|
Loading…
Reference in New Issue
Block a user