mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-16 01:59:52 +00:00
Compare commits
2 Commits
langchain-
...
nc/repl-li
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7ce608ef7f | ||
|
|
d6ae25bf70 |
57
libs/repl/Makefile
Normal file
57
libs/repl/Makefile
Normal file
@@ -0,0 +1,57 @@
|
||||
.PHONY: all format lint test tests test_watch integration_tests docker_tests help extended_tests
|
||||
|
||||
# Default target executed when no arguments are given to make.
|
||||
all: help
|
||||
|
||||
# Define a variable for the test file path.
|
||||
TEST_FILE ?= tests/unit_tests/
|
||||
|
||||
test:
|
||||
poetry run pytest $(TEST_FILE)
|
||||
|
||||
tests:
|
||||
poetry run pytest $(TEST_FILE)
|
||||
|
||||
test_watch:
|
||||
poetry run ptw --now . -- tests/unit_tests
|
||||
|
||||
extended_tests:
|
||||
poetry run pytest --only-extended tests/unit_tests
|
||||
|
||||
|
||||
######################
|
||||
# LINTING AND FORMATTING
|
||||
######################
|
||||
|
||||
# Define a variable for Python and notebook files.
|
||||
PYTHON_FILES=.
|
||||
lint format: PYTHON_FILES=.
|
||||
lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/experimental --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
|
||||
|
||||
lint lint_diff:
|
||||
poetry run ruff .
|
||||
poetry run black $(PYTHON_FILES) --check
|
||||
poetry run mypy $(PYTHON_FILES)
|
||||
|
||||
format format_diff:
|
||||
poetry run black $(PYTHON_FILES)
|
||||
poetry run ruff --select I --fix $(PYTHON_FILES)
|
||||
|
||||
spell_check:
|
||||
poetry run codespell --toml pyproject.toml
|
||||
|
||||
spell_fix:
|
||||
poetry run codespell --toml pyproject.toml -w
|
||||
|
||||
######################
|
||||
# HELP
|
||||
######################
|
||||
|
||||
help:
|
||||
@echo '----'
|
||||
@echo 'format - run code formatters'
|
||||
@echo 'lint - run linters'
|
||||
@echo 'test - run unit tests'
|
||||
@echo 'tests - run unit tests'
|
||||
@echo 'test TEST_FILE=<test_file> - run all tests in file'
|
||||
@echo 'test_watch - run unit tests in watch mode'
|
||||
8
libs/repl/README.md
Normal file
8
libs/repl/README.md
Normal file
@@ -0,0 +1,8 @@
|
||||
# 🦜️🧪 LangChain REPL
|
||||
|
||||
This package holds utilities and tools for executing code generated by an LLM.
|
||||
|
||||
> [!WARNING]
|
||||
> Portions of the code in this package may be dangerous if not properly deployed
|
||||
> in a sandboxed environment. Please use the REPL tool that matches your
|
||||
> security needs.
|
||||
0
libs/repl/langchain_repl/__init__.py
Normal file
0
libs/repl/langchain_repl/__init__.py
Normal file
0
libs/repl/langchain_repl/py.typed
Normal file
0
libs/repl/langchain_repl/py.typed
Normal file
0
libs/repl/langchain_repl/repls/__init__.py
Normal file
0
libs/repl/langchain_repl/repls/__init__.py
Normal file
71
libs/repl/langchain_repl/repls/python.py
Normal file
71
libs/repl/langchain_repl/repls/python.py
Normal file
@@ -0,0 +1,71 @@
|
||||
import functools
|
||||
import logging
|
||||
import multiprocessing
|
||||
import sys
|
||||
from io import StringIO
|
||||
from typing import Dict, Optional
|
||||
|
||||
from langchain.pydantic_v1 import BaseModel, Field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def warn_once() -> None:
|
||||
"""Warn once about the dangers of PythonREPL."""
|
||||
logger.warning("Python REPL can execute arbitrary code. Use with caution.")
|
||||
|
||||
|
||||
class PythonREPL(BaseModel):
|
||||
"""Simulates a standalone Python REPL."""
|
||||
|
||||
globals: Optional[Dict] = Field(default_factory=dict, alias="_globals")
|
||||
locals: Optional[Dict] = Field(default_factory=dict, alias="_locals")
|
||||
|
||||
@classmethod
|
||||
def worker(
|
||||
cls,
|
||||
command: str,
|
||||
globals: Optional[Dict],
|
||||
locals: Optional[Dict],
|
||||
queue: multiprocessing.Queue,
|
||||
) -> None:
|
||||
old_stdout = sys.stdout
|
||||
sys.stdout = mystdout = StringIO()
|
||||
try:
|
||||
exec(command, globals, locals)
|
||||
sys.stdout = old_stdout
|
||||
queue.put(mystdout.getvalue())
|
||||
except Exception as e:
|
||||
sys.stdout = old_stdout
|
||||
queue.put(repr(e))
|
||||
|
||||
def run(self, command: str, timeout: Optional[int] = None) -> str:
|
||||
"""Run command with own globals/locals and returns anything printed.
|
||||
Timeout after the specified number of seconds."""
|
||||
|
||||
# Warn against dangers of PythonREPL
|
||||
warn_once()
|
||||
|
||||
queue: multiprocessing.Queue = multiprocessing.Queue()
|
||||
|
||||
# Only use multiprocessing if we are enforcing a timeout
|
||||
if timeout is not None:
|
||||
# create a Process
|
||||
p = multiprocessing.Process(
|
||||
target=self.worker, args=(command, self.globals, self.locals, queue)
|
||||
)
|
||||
|
||||
# start it
|
||||
p.start()
|
||||
|
||||
# wait for the process to finish or kill it after timeout seconds
|
||||
p.join(timeout)
|
||||
|
||||
if p.is_alive():
|
||||
p.terminate()
|
||||
return "Execution timed out"
|
||||
else:
|
||||
self.worker(command, self.globals, self.locals, queue)
|
||||
# get the result from the worker function
|
||||
return queue.get()
|
||||
0
libs/repl/langchain_repl/tools/__init__.py
Normal file
0
libs/repl/langchain_repl/tools/__init__.py
Normal file
149
libs/repl/langchain_repl/tools/python.py
Normal file
149
libs/repl/langchain_repl/tools/python.py
Normal file
@@ -0,0 +1,149 @@
|
||||
"""A tool for running python code in a REPL."""
|
||||
|
||||
import ast
|
||||
import asyncio
|
||||
import re
|
||||
import sys
|
||||
from contextlib import redirect_stdout
|
||||
from io import StringIO
|
||||
from typing import Any, Dict, Optional, Type
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForToolRun,
|
||||
CallbackManagerForToolRun,
|
||||
)
|
||||
from langchain.pydantic_v1 import BaseModel, Field, root_validator
|
||||
from langchain.tools.base import BaseTool
|
||||
from langchain.utilities import PythonREPL
|
||||
|
||||
|
||||
def _get_default_python_repl() -> PythonREPL:
|
||||
return PythonREPL(_globals=globals(), _locals=None)
|
||||
|
||||
|
||||
def sanitize_input(query: str) -> str:
|
||||
"""Sanitize input to the python REPL.
|
||||
Remove whitespace, backtick & python (if llm mistakes python console as terminal)
|
||||
|
||||
Args:
|
||||
query: The query to sanitize
|
||||
|
||||
Returns:
|
||||
str: The sanitized query
|
||||
"""
|
||||
|
||||
# Removes `, whitespace & python from start
|
||||
query = re.sub(r"^(\s|`)*(?i:python)?\s*", "", query)
|
||||
# Removes whitespace & ` from end
|
||||
query = re.sub(r"(\s|`)*$", "", query)
|
||||
return query
|
||||
|
||||
|
||||
class PythonREPLTool(BaseTool):
|
||||
"""A tool for running python code in a REPL."""
|
||||
|
||||
name: str = "Python_REPL"
|
||||
description: str = (
|
||||
"A Python shell. Use this to execute python commands. "
|
||||
"Input should be a valid python command. "
|
||||
"If you want to see the output of a value, you should print it out "
|
||||
"with `print(...)`."
|
||||
)
|
||||
python_repl: PythonREPL = Field(default_factory=_get_default_python_repl)
|
||||
sanitize_input: bool = True
|
||||
|
||||
def _run(
|
||||
self,
|
||||
query: str,
|
||||
run_manager: Optional[CallbackManagerForToolRun] = None,
|
||||
) -> Any:
|
||||
"""Use the tool."""
|
||||
if self.sanitize_input:
|
||||
query = sanitize_input(query)
|
||||
return self.python_repl.run(query)
|
||||
|
||||
async def _arun(
|
||||
self,
|
||||
query: str,
|
||||
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
||||
) -> Any:
|
||||
"""Use the tool asynchronously."""
|
||||
if self.sanitize_input:
|
||||
query = sanitize_input(query)
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
result = await loop.run_in_executor(None, self.run, query)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class PythonInputs(BaseModel):
|
||||
query: str = Field(description="code snippet to run")
|
||||
|
||||
|
||||
class PythonAstREPLTool(BaseTool):
|
||||
"""A tool for running python code in a REPL."""
|
||||
|
||||
name: str = "python_repl_ast"
|
||||
description: str = (
|
||||
"A Python shell. Use this to execute python commands. "
|
||||
"Input should be a valid python command. "
|
||||
"When using this tool, sometimes output is abbreviated - "
|
||||
"make sure it does not look abbreviated before using it in your answer."
|
||||
)
|
||||
globals: Optional[Dict] = Field(default_factory=dict)
|
||||
locals: Optional[Dict] = Field(default_factory=dict)
|
||||
sanitize_input: bool = True
|
||||
args_schema: Type[BaseModel] = PythonInputs
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_python_version(cls, values: Dict) -> Dict:
|
||||
"""Validate valid python version."""
|
||||
if sys.version_info < (3, 9):
|
||||
raise ValueError(
|
||||
"This tool relies on Python 3.9 or higher "
|
||||
"(as it uses new functionality in the `ast` module, "
|
||||
f"you have Python version: {sys.version}"
|
||||
)
|
||||
return values
|
||||
|
||||
def _run(
|
||||
self,
|
||||
query: str,
|
||||
run_manager: Optional[CallbackManagerForToolRun] = None,
|
||||
) -> str:
|
||||
"""Use the tool."""
|
||||
try:
|
||||
if self.sanitize_input:
|
||||
query = sanitize_input(query)
|
||||
tree = ast.parse(query)
|
||||
module = ast.Module(tree.body[:-1], type_ignores=[])
|
||||
exec(ast.unparse(module), self.globals, self.locals) # type: ignore
|
||||
module_end = ast.Module(tree.body[-1:], type_ignores=[])
|
||||
module_end_str = ast.unparse(module_end) # type: ignore
|
||||
io_buffer = StringIO()
|
||||
try:
|
||||
with redirect_stdout(io_buffer):
|
||||
ret = eval(module_end_str, self.globals, self.locals)
|
||||
if ret is None:
|
||||
return io_buffer.getvalue()
|
||||
else:
|
||||
return ret
|
||||
except Exception:
|
||||
with redirect_stdout(io_buffer):
|
||||
exec(module_end_str, self.globals, self.locals)
|
||||
return io_buffer.getvalue()
|
||||
except Exception as e:
|
||||
return "{}: {}".format(type(e).__name__, str(e))
|
||||
|
||||
async def _arun(
|
||||
self,
|
||||
query: str,
|
||||
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
||||
) -> Any:
|
||||
"""Use the tool asynchronously."""
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
result = await loop.run_in_executor(None, self._run, query)
|
||||
|
||||
return result
|
||||
3309
libs/repl/poetry.lock
generated
Normal file
3309
libs/repl/poetry.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
5
libs/repl/poetry.toml
Normal file
5
libs/repl/poetry.toml
Normal file
@@ -0,0 +1,5 @@
|
||||
[virtualenvs]
|
||||
in-project = true
|
||||
|
||||
[installer]
|
||||
modern-installation = false
|
||||
78
libs/repl/pyproject.toml
Normal file
78
libs/repl/pyproject.toml
Normal file
@@ -0,0 +1,78 @@
|
||||
[tool.poetry]
|
||||
name = "langchain-repl"
|
||||
version = "0.0.27"
|
||||
description = "Building applications with LLMs through composability"
|
||||
authors = []
|
||||
license = "MIT"
|
||||
readme = "README.md"
|
||||
repository = "https://github.com/langchain-ai/langchain"
|
||||
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.8.1,<4.0"
|
||||
langchain = ">=0.0.308"
|
||||
|
||||
|
||||
[tool.poetry.group.lint.dependencies]
|
||||
ruff = "^0.0.249"
|
||||
black = "^23.1.0"
|
||||
|
||||
[tool.poetry.group.typing.dependencies]
|
||||
mypy = "^0.991"
|
||||
types-pyyaml = "^6.0.12.2"
|
||||
types-requests = "^2.28.11.5"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
jupyter = "^1.0.0"
|
||||
setuptools = "^67.6.1"
|
||||
|
||||
[tool.poetry.group.test.dependencies]
|
||||
# The only dependencies that should be added are
|
||||
# dependencies used for running tests (e.g., pytest, freezegun, response).
|
||||
# Any dependencies that do not meet that criteria will be removed.
|
||||
pytest = "^7.3.0"
|
||||
|
||||
# An extra used to be able to add extended testing.
|
||||
# Please use new-line on formatting to make it easier to add new packages without
|
||||
# merge-conflicts
|
||||
[tool.poetry.extras]
|
||||
extended_testing = [
|
||||
]
|
||||
|
||||
[tool.ruff]
|
||||
select = [
|
||||
"E", # pycodestyle
|
||||
"F", # pyflakes
|
||||
"I", # isort
|
||||
]
|
||||
|
||||
[tool.mypy]
|
||||
ignore_missing_imports = "True"
|
||||
disallow_untyped_defs = "True"
|
||||
exclude = ["notebooks", "examples", "example_data"]
|
||||
|
||||
[tool.coverage.run]
|
||||
omit = [
|
||||
"tests/*",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core>=1.0.0"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
# --strict-markers will raise errors on unknown marks.
|
||||
# https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks
|
||||
#
|
||||
# https://docs.pytest.org/en/7.1.x/reference/reference.html
|
||||
# --strict-config any warnings encountered while parsing the `pytest`
|
||||
# section of the configuration file raise errors.
|
||||
#
|
||||
# https://github.com/tophat/syrupy
|
||||
# --snapshot-warn-unused Prints a warning on unused snapshots rather than fail the test suite.
|
||||
addopts = "--strict-markers --strict-config --durations=5"
|
||||
# Registering custom markers.
|
||||
# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
|
||||
markers = [
|
||||
"requires: mark tests as requiring a specific library"
|
||||
]
|
||||
0
libs/repl/tests/__init__.py
Normal file
0
libs/repl/tests/__init__.py
Normal file
112
libs/repl/tests/unit_tests/test_python_repl.py
Normal file
112
libs/repl/tests/unit_tests/test_python_repl.py
Normal file
@@ -0,0 +1,112 @@
|
||||
"""Test functionality of Python REPL."""
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_repl.repls.python import PythonREPL
|
||||
from langchain_repl.tools.python import PythonAstREPLTool, PythonREPLTool
|
||||
|
||||
_SAMPLE_CODE = """
|
||||
```
|
||||
def multiply():
|
||||
print(5*6)
|
||||
multiply()
|
||||
```
|
||||
"""
|
||||
|
||||
_AST_SAMPLE_CODE = """
|
||||
```
|
||||
def multiply():
|
||||
return(5*6)
|
||||
multiply()
|
||||
```
|
||||
"""
|
||||
|
||||
_AST_SAMPLE_CODE_EXECUTE = """
|
||||
```
|
||||
def multiply(a, b):
|
||||
return(5*6)
|
||||
a = 5
|
||||
b = 6
|
||||
|
||||
multiply(a, b)
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
def test_python_repl() -> None:
|
||||
"""Test functionality when globals/locals are not provided."""
|
||||
repl = PythonREPL()
|
||||
|
||||
# Run a simple initial command.
|
||||
repl.run("foo = 1")
|
||||
assert repl.locals is not None
|
||||
assert repl.locals["foo"] == 1
|
||||
|
||||
# Now run a command that accesses `foo` to make sure it still has it.
|
||||
repl.run("bar = foo * 2")
|
||||
assert repl.locals is not None
|
||||
assert repl.locals["bar"] == 2
|
||||
|
||||
|
||||
def test_python_repl_no_previous_variables() -> None:
|
||||
"""Test that it does not have access to variables created outside the scope."""
|
||||
foo = 3 # noqa: F841
|
||||
repl = PythonREPL()
|
||||
output = repl.run("print(foo)")
|
||||
assert output == """NameError("name 'foo' is not defined")"""
|
||||
|
||||
|
||||
def test_python_repl_pass_in_locals() -> None:
|
||||
"""Test functionality when passing in locals."""
|
||||
_locals = {"foo": 4}
|
||||
repl = PythonREPL(_locals=_locals)
|
||||
repl.run("bar = foo * 2")
|
||||
assert repl.locals is not None
|
||||
assert repl.locals["bar"] == 8
|
||||
|
||||
|
||||
def test_functionality() -> None:
|
||||
"""Test correct functionality."""
|
||||
chain = PythonREPL()
|
||||
code = "print(1 + 1)"
|
||||
output = chain.run(code)
|
||||
assert output == "2\n"
|
||||
|
||||
|
||||
def test_functionality_multiline() -> None:
|
||||
"""Test correct functionality for ChatGPT multiline commands."""
|
||||
chain = PythonREPL()
|
||||
tool = PythonREPLTool(python_repl=chain)
|
||||
output = tool.run(_SAMPLE_CODE)
|
||||
assert output == "30\n"
|
||||
|
||||
|
||||
def test_python_ast_repl_multiline() -> None:
|
||||
"""Test correct functionality for ChatGPT multiline commands."""
|
||||
if sys.version_info < (3, 9):
|
||||
pytest.skip("Python 3.9+ is required for this test")
|
||||
tool = PythonAstREPLTool()
|
||||
output = tool.run(_AST_SAMPLE_CODE)
|
||||
assert output == 30
|
||||
|
||||
|
||||
def test_python_ast_repl_multi_statement() -> None:
|
||||
"""Test correct functionality for ChatGPT multi statement commands."""
|
||||
if sys.version_info < (3, 9):
|
||||
pytest.skip("Python 3.9+ is required for this test")
|
||||
tool = PythonAstREPLTool()
|
||||
output = tool.run(_AST_SAMPLE_CODE_EXECUTE)
|
||||
assert output == 30
|
||||
|
||||
|
||||
def test_function() -> None:
|
||||
"""Test correct functionality."""
|
||||
chain = PythonREPL()
|
||||
code = "def add(a, b): " " return a + b"
|
||||
output = chain.run(code)
|
||||
assert output == ""
|
||||
|
||||
code = "print(add(1, 2))"
|
||||
output = chain.run(code)
|
||||
assert output == "3\n"
|
||||
164
libs/repl/tests/unit_tests/test_python_tool.py
Normal file
164
libs/repl/tests/unit_tests/test_python_tool.py
Normal file
@@ -0,0 +1,164 @@
|
||||
"""Test Python REPL Tools."""
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from langchain_repl.tools.python import (
|
||||
PythonAstREPLTool,
|
||||
PythonREPLTool,
|
||||
sanitize_input,
|
||||
)
|
||||
|
||||
|
||||
def test_python_repl_tool_single_input() -> None:
|
||||
"""Test that the python REPL tool works with a single input."""
|
||||
tool = PythonREPLTool()
|
||||
assert tool.is_single_input
|
||||
assert int(tool.run("print(1 + 1)").strip()) == 2
|
||||
|
||||
|
||||
def test_python_repl_print() -> None:
|
||||
program = """
|
||||
import numpy as np
|
||||
v1 = np.array([1, 2, 3])
|
||||
v2 = np.array([4, 5, 6])
|
||||
dot_product = np.dot(v1, v2)
|
||||
print("The dot product is {:d}.".format(dot_product))
|
||||
"""
|
||||
tool = PythonREPLTool()
|
||||
assert tool.run(program) == "The dot product is 32.\n"
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
|
||||
)
|
||||
def test_python_ast_repl_tool_single_input() -> None:
|
||||
"""Test that the python REPL tool works with a single input."""
|
||||
tool = PythonAstREPLTool()
|
||||
assert tool.is_single_input
|
||||
assert tool.run("1 + 1") == 2
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
|
||||
)
|
||||
def test_python_ast_repl_return() -> None:
|
||||
program = """
|
||||
```
|
||||
import numpy as np
|
||||
v1 = np.array([1, 2, 3])
|
||||
v2 = np.array([4, 5, 6])
|
||||
dot_product = np.dot(v1, v2)
|
||||
int(dot_product)
|
||||
```
|
||||
"""
|
||||
tool = PythonAstREPLTool()
|
||||
assert tool.run(program) == 32
|
||||
|
||||
program = """
|
||||
```python
|
||||
import numpy as np
|
||||
v1 = np.array([1, 2, 3])
|
||||
v2 = np.array([4, 5, 6])
|
||||
dot_product = np.dot(v1, v2)
|
||||
int(dot_product)
|
||||
```
|
||||
"""
|
||||
assert tool.run(program) == 32
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
|
||||
)
|
||||
def test_python_ast_repl_print() -> None:
|
||||
program = """python
|
||||
string = "racecar"
|
||||
if string == string[::-1]:
|
||||
print(string, "is a palindrome")
|
||||
else:
|
||||
print(string, "is not a palindrome")"""
|
||||
tool = PythonAstREPLTool()
|
||||
assert tool.run(program) == "racecar is a palindrome\n"
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
|
||||
)
|
||||
def test_repl_print_python_backticks() -> None:
|
||||
program = "`print('`python` is a great language.')`"
|
||||
tool = PythonAstREPLTool()
|
||||
assert tool.run(program) == "`python` is a great language.\n"
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
|
||||
)
|
||||
def test_python_ast_repl_raise_exception() -> None:
|
||||
data = {"Name": ["John", "Alice"], "Age": [30, 25]}
|
||||
program = """
|
||||
import pandas as pd
|
||||
df = pd.DataFrame(data)
|
||||
df['Gender']
|
||||
"""
|
||||
tool = PythonAstREPLTool(locals={"data": data})
|
||||
expected_outputs = (
|
||||
"KeyError: 'Gender'",
|
||||
"ModuleNotFoundError: No module named 'pandas'",
|
||||
)
|
||||
assert tool.run(program) in expected_outputs
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
|
||||
)
|
||||
def test_python_ast_repl_one_line_print() -> None:
|
||||
program = 'print("The square of {} is {:.2f}".format(3, 3**2))'
|
||||
tool = PythonAstREPLTool()
|
||||
assert tool.run(program) == "The square of 3 is 9.00\n"
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
|
||||
)
|
||||
def test_python_ast_repl_one_line_return() -> None:
|
||||
arr = np.array([1, 2, 3, 4, 5])
|
||||
tool = PythonAstREPLTool(locals={"arr": arr})
|
||||
program = "`(arr**2).sum() # Returns sum of squares`"
|
||||
assert tool.run(program) == 55
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
|
||||
)
|
||||
def test_python_ast_repl_one_line_exception() -> None:
|
||||
program = "[1, 2, 3][4]"
|
||||
tool = PythonAstREPLTool()
|
||||
assert tool.run(program) == "IndexError: list index out of range"
|
||||
|
||||
|
||||
def test_sanitize_input() -> None:
|
||||
query = """
|
||||
```
|
||||
p = 5
|
||||
```
|
||||
"""
|
||||
expected = "p = 5"
|
||||
actual = sanitize_input(query)
|
||||
assert expected == actual
|
||||
|
||||
query = """
|
||||
```python
|
||||
p = 5
|
||||
```
|
||||
"""
|
||||
expected = "p = 5"
|
||||
actual = sanitize_input(query)
|
||||
assert expected == actual
|
||||
|
||||
query = """
|
||||
p = 5
|
||||
"""
|
||||
expected = "p = 5"
|
||||
actual = sanitize_input(query)
|
||||
assert expected == actual
|
||||
Reference in New Issue
Block a user