experimental: clean python repl input(experimental:Added code for PythonREPL) (#20930)

Update python.py(experimental:Added code for PythonREPL)

Added code for PythonREPL, defining a static method 'sanitize_input'
that takes the string 'query' as input and returns a sanitizing string.
The purpose of this method is to remove unwanted characters from the
input string, Specifically:

1. Delete the whitespace at the beginning and end of the string (' \s').
2. Remove the quotation marks (`` ` ``) at the beginning and end of the
string.
3. Remove the keyword "python" at the beginning of the string (case
insensitive) because the user may have typed it.

This method uses regular expressions (regex) to implement sanitizing.

It all started with this code:
from langchain.agents import Tool
from langchain_experimental.utilities import PythonREPL

python_repl = PythonREPL()
repl_tool = Tool(
    name="python_repl",
description="Remove redundant formatting marks at the beginning and end
of source code from input.Use a Python shell to execute python commands.
If you want to see the output of a value, you should print it out with
`print(...)`.",
    func=python_repl.run,
)

When I call the agent to write a piece of code for me and execute it
with the defined code, I must get an error: SyntaxError('invalid
syntax', ('<string>', 1, 1,'In', 1, 2))

After checking, I found that pythonREPL has less formatting of input
code than the soon-to-be deprecated pythonREPL tool, so I added this
step to it, so that no matter what code I ask the agent to write for me,
it can be executed smoothly and get the output result.
I have tried modifying the prompt words to solve this problem before,
but it did not work, and by adding a simple format check, the problem is
well resolved.
<img width="1271" alt="image"
src="https://github.com/langchain-ai/langchain/assets/164149097/c49a685f-d246-4b11-b655-fd952fc2f04c">

---------

Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
Liu Xiaodong 2024-05-01 13:19:09 +08:00 committed by GitHub
parent 1fdf63fa6c
commit 3b473d10f2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 54 additions and 1 deletions

View File

@ -1,6 +1,7 @@
import functools
import logging
import multiprocessing
import re
import sys
from io import StringIO
from typing import Dict, Optional
@ -22,6 +23,23 @@ class PythonREPL(BaseModel):
globals: Optional[Dict] = Field(default_factory=dict, alias="_globals")
locals: Optional[Dict] = Field(default_factory=dict, alias="_locals")
@staticmethod
def sanitize_input(query: str) -> str:
"""Sanitize input to the python REPL.
Remove whitespace, backtick & python
(if llm mistakes python console as terminal)
Args:
query: The query to sanitize
Returns:
str: The sanitized query
"""
query = re.sub(r"^(\s|`)*(?i:python)?\s*", "", query)
query = re.sub(r"(\s|`)*$", "", query)
return query
@classmethod
def worker(
cls,
@ -33,7 +51,8 @@ class PythonREPL(BaseModel):
old_stdout = sys.stdout
sys.stdout = mystdout = StringIO()
try:
exec(command, globals, locals)
cleaned_command = cls.sanitize_input(command)
exec(cleaned_command, globals, locals)
sys.stdout = old_stdout
queue.put(mystdout.getvalue())
except Exception as e:

View File

@ -0,0 +1,34 @@
import unittest
from langchain_experimental.utilities import PythonREPL
class TestSanitizeInput(unittest.TestCase):
def test_whitespace_removal(self) -> None:
query = " print('Hello, world!') "
sanitized_query = PythonREPL.sanitize_input(query)
self.assertEqual(sanitized_query, "print('Hello, world!')")
def test_python_removal(self) -> None:
query = "python print('Hello, world!') "
sanitized_query = PythonREPL.sanitize_input(query)
self.assertEqual(sanitized_query, "print('Hello, world!')")
def test_backtick_removal(self) -> None:
query = "`print('Hello, world!')`"
sanitized_query = PythonREPL.sanitize_input(query)
self.assertEqual(sanitized_query, "print('Hello, world!')")
def test_combined_removal(self) -> None:
query = " `python print('Hello, world!')` "
sanitized_query = PythonREPL.sanitize_input(query)
self.assertEqual(sanitized_query, "print('Hello, world!')")
def test_mixed_case_removal(self) -> None:
query = " pYtHoN print('Hello, world!') "
sanitized_query = PythonREPL.sanitize_input(query)
self.assertEqual(sanitized_query, "print('Hello, world!')")
if __name__ == "__main__":
unittest.main()