Add python,pandas,xorbits,spark agents to experimental (#11774)

See for contex
https://github.com/langchain-ai/langchain/discussions/11680
This commit is contained in:
Eugene Yurtsev 2023-10-13 17:36:44 -04:00 committed by GitHub
parent d6e34ca2ee
commit 0d37b4c27d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 1171 additions and 0 deletions

View File

@ -0,0 +1 @@
"""Pandas toolkit."""

View File

@ -0,0 +1,341 @@
"""Agent for working with pandas objects."""
from typing import Any, Dict, List, Optional, Sequence, Tuple
from langchain.agents.agent import AgentExecutor, BaseSingleActionAgent
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
from langchain.agents.types import AgentType
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain
from langchain.schema import BasePromptTemplate
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.messages import SystemMessage
from langchain.tools import BaseTool
from langchain_experimental.agents.agent_toolkits.pandas.prompt import (
FUNCTIONS_WITH_DF,
FUNCTIONS_WITH_MULTI_DF,
MULTI_DF_PREFIX,
MULTI_DF_PREFIX_FUNCTIONS,
PREFIX,
PREFIX_FUNCTIONS,
SUFFIX_NO_DF,
SUFFIX_WITH_DF,
SUFFIX_WITH_MULTI_DF,
)
from langchain_experimental.tools.python.tool import PythonAstREPLTool
def _get_multi_prompt(
dfs: List[Any],
prefix: Optional[str] = None,
suffix: Optional[str] = None,
input_variables: Optional[List[str]] = None,
include_df_in_prompt: Optional[bool] = True,
number_of_head_rows: int = 5,
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
num_dfs = len(dfs)
if suffix is not None:
suffix_to_use = suffix
include_dfs_head = True
elif include_df_in_prompt:
suffix_to_use = SUFFIX_WITH_MULTI_DF
include_dfs_head = True
else:
suffix_to_use = SUFFIX_NO_DF
include_dfs_head = False
if input_variables is None:
input_variables = ["input", "agent_scratchpad", "num_dfs"]
if include_dfs_head:
input_variables += ["dfs_head"]
if prefix is None:
prefix = MULTI_DF_PREFIX
df_locals = {}
for i, dataframe in enumerate(dfs):
df_locals[f"df{i + 1}"] = dataframe
tools = [PythonAstREPLTool(locals=df_locals)]
prompt = ZeroShotAgent.create_prompt(
tools, prefix=prefix, suffix=suffix_to_use, input_variables=input_variables
)
partial_prompt = prompt.partial()
if "dfs_head" in input_variables:
dfs_head = "\n\n".join([d.head(number_of_head_rows).to_markdown() for d in dfs])
partial_prompt = partial_prompt.partial(num_dfs=str(num_dfs), dfs_head=dfs_head)
if "num_dfs" in input_variables:
partial_prompt = partial_prompt.partial(num_dfs=str(num_dfs))
return partial_prompt, tools
def _get_single_prompt(
df: Any,
prefix: Optional[str] = None,
suffix: Optional[str] = None,
input_variables: Optional[List[str]] = None,
include_df_in_prompt: Optional[bool] = True,
number_of_head_rows: int = 5,
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
if suffix is not None:
suffix_to_use = suffix
include_df_head = True
elif include_df_in_prompt:
suffix_to_use = SUFFIX_WITH_DF
include_df_head = True
else:
suffix_to_use = SUFFIX_NO_DF
include_df_head = False
if input_variables is None:
input_variables = ["input", "agent_scratchpad"]
if include_df_head:
input_variables += ["df_head"]
if prefix is None:
prefix = PREFIX
tools = [PythonAstREPLTool(locals={"df": df})]
prompt = ZeroShotAgent.create_prompt(
tools, prefix=prefix, suffix=suffix_to_use, input_variables=input_variables
)
partial_prompt = prompt.partial()
if "df_head" in input_variables:
partial_prompt = partial_prompt.partial(
df_head=str(df.head(number_of_head_rows).to_markdown())
)
return partial_prompt, tools
def _get_prompt_and_tools(
df: Any,
prefix: Optional[str] = None,
suffix: Optional[str] = None,
input_variables: Optional[List[str]] = None,
include_df_in_prompt: Optional[bool] = True,
number_of_head_rows: int = 5,
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
try:
import pandas as pd
pd.set_option("display.max_columns", None)
except ImportError:
raise ImportError(
"pandas package not found, please install with `pip install pandas`"
)
if include_df_in_prompt is not None and suffix is not None:
raise ValueError("If suffix is specified, include_df_in_prompt should not be.")
if isinstance(df, list):
for item in df:
if not isinstance(item, pd.DataFrame):
raise ValueError(f"Expected pandas object, got {type(df)}")
return _get_multi_prompt(
df,
prefix=prefix,
suffix=suffix,
input_variables=input_variables,
include_df_in_prompt=include_df_in_prompt,
number_of_head_rows=number_of_head_rows,
)
else:
if not isinstance(df, pd.DataFrame):
raise ValueError(f"Expected pandas object, got {type(df)}")
return _get_single_prompt(
df,
prefix=prefix,
suffix=suffix,
input_variables=input_variables,
include_df_in_prompt=include_df_in_prompt,
number_of_head_rows=number_of_head_rows,
)
def _get_functions_single_prompt(
df: Any,
prefix: Optional[str] = None,
suffix: Optional[str] = None,
include_df_in_prompt: Optional[bool] = True,
number_of_head_rows: int = 5,
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
if suffix is not None:
suffix_to_use = suffix
if include_df_in_prompt:
suffix_to_use = suffix_to_use.format(
df_head=str(df.head(number_of_head_rows).to_markdown())
)
elif include_df_in_prompt:
suffix_to_use = FUNCTIONS_WITH_DF.format(
df_head=str(df.head(number_of_head_rows).to_markdown())
)
else:
suffix_to_use = ""
if prefix is None:
prefix = PREFIX_FUNCTIONS
tools = [PythonAstREPLTool(locals={"df": df})]
system_message = SystemMessage(content=prefix + suffix_to_use)
prompt = OpenAIFunctionsAgent.create_prompt(system_message=system_message)
return prompt, tools
def _get_functions_multi_prompt(
dfs: Any,
prefix: Optional[str] = None,
suffix: Optional[str] = None,
include_df_in_prompt: Optional[bool] = True,
number_of_head_rows: int = 5,
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
if suffix is not None:
suffix_to_use = suffix
if include_df_in_prompt:
dfs_head = "\n\n".join(
[d.head(number_of_head_rows).to_markdown() for d in dfs]
)
suffix_to_use = suffix_to_use.format(
dfs_head=dfs_head,
)
elif include_df_in_prompt:
dfs_head = "\n\n".join([d.head(number_of_head_rows).to_markdown() for d in dfs])
suffix_to_use = FUNCTIONS_WITH_MULTI_DF.format(
dfs_head=dfs_head,
)
else:
suffix_to_use = ""
if prefix is None:
prefix = MULTI_DF_PREFIX_FUNCTIONS
prefix = prefix.format(num_dfs=str(len(dfs)))
df_locals = {}
for i, dataframe in enumerate(dfs):
df_locals[f"df{i + 1}"] = dataframe
tools = [PythonAstREPLTool(locals=df_locals)]
system_message = SystemMessage(content=prefix + suffix_to_use)
prompt = OpenAIFunctionsAgent.create_prompt(system_message=system_message)
return prompt, tools
def _get_functions_prompt_and_tools(
df: Any,
prefix: Optional[str] = None,
suffix: Optional[str] = None,
input_variables: Optional[List[str]] = None,
include_df_in_prompt: Optional[bool] = True,
number_of_head_rows: int = 5,
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
try:
import pandas as pd
pd.set_option("display.max_columns", None)
except ImportError:
raise ImportError(
"pandas package not found, please install with `pip install pandas`"
)
if input_variables is not None:
raise ValueError("`input_variables` is not supported at the moment.")
if include_df_in_prompt is not None and suffix is not None:
raise ValueError("If suffix is specified, include_df_in_prompt should not be.")
if isinstance(df, list):
for item in df:
if not isinstance(item, pd.DataFrame):
raise ValueError(f"Expected pandas object, got {type(df)}")
return _get_functions_multi_prompt(
df,
prefix=prefix,
suffix=suffix,
include_df_in_prompt=include_df_in_prompt,
number_of_head_rows=number_of_head_rows,
)
else:
if not isinstance(df, pd.DataFrame):
raise ValueError(f"Expected pandas object, got {type(df)}")
return _get_functions_single_prompt(
df,
prefix=prefix,
suffix=suffix,
include_df_in_prompt=include_df_in_prompt,
number_of_head_rows=number_of_head_rows,
)
def create_pandas_dataframe_agent(
llm: BaseLanguageModel,
df: Any,
agent_type: AgentType = AgentType.ZERO_SHOT_REACT_DESCRIPTION,
callback_manager: Optional[BaseCallbackManager] = None,
prefix: Optional[str] = None,
suffix: Optional[str] = None,
input_variables: Optional[List[str]] = None,
verbose: bool = False,
return_intermediate_steps: bool = False,
max_iterations: Optional[int] = 15,
max_execution_time: Optional[float] = None,
early_stopping_method: str = "force",
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
include_df_in_prompt: Optional[bool] = True,
number_of_head_rows: int = 5,
extra_tools: Sequence[BaseTool] = (),
**kwargs: Dict[str, Any],
) -> AgentExecutor:
"""Construct a pandas agent from an LLM and dataframe."""
agent: BaseSingleActionAgent
if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION:
prompt, base_tools = _get_prompt_and_tools(
df,
prefix=prefix,
suffix=suffix,
input_variables=input_variables,
include_df_in_prompt=include_df_in_prompt,
number_of_head_rows=number_of_head_rows,
)
tools = base_tools + list(extra_tools)
llm_chain = LLMChain(
llm=llm,
prompt=prompt,
callback_manager=callback_manager,
)
tool_names = [tool.name for tool in tools]
agent = ZeroShotAgent(
llm_chain=llm_chain,
allowed_tools=tool_names,
callback_manager=callback_manager,
**kwargs,
)
elif agent_type == AgentType.OPENAI_FUNCTIONS:
_prompt, base_tools = _get_functions_prompt_and_tools(
df,
prefix=prefix,
suffix=suffix,
input_variables=input_variables,
include_df_in_prompt=include_df_in_prompt,
number_of_head_rows=number_of_head_rows,
)
tools = base_tools + list(extra_tools)
agent = OpenAIFunctionsAgent(
llm=llm,
prompt=_prompt,
tools=tools,
callback_manager=callback_manager,
**kwargs,
)
else:
raise ValueError(f"Agent type {agent_type} not supported at the moment.")
return AgentExecutor.from_agent_and_tools(
agent=agent,
tools=tools,
callback_manager=callback_manager,
verbose=verbose,
return_intermediate_steps=return_intermediate_steps,
max_iterations=max_iterations,
max_execution_time=max_execution_time,
early_stopping_method=early_stopping_method,
**(agent_executor_kwargs or {}),
)

View File

@ -0,0 +1,44 @@
# flake8: noqa
PREFIX = """
You are working with a pandas dataframe in Python. The name of the dataframe is `df`.
You should use the tools below to answer the question posed of you:"""
MULTI_DF_PREFIX = """
You are working with {num_dfs} pandas dataframes in Python named df1, df2, etc. You
should use the tools below to answer the question posed of you:"""
SUFFIX_NO_DF = """
Begin!
Question: {input}
{agent_scratchpad}"""
SUFFIX_WITH_DF = """
This is the result of `print(df.head())`:
{df_head}
Begin!
Question: {input}
{agent_scratchpad}"""
SUFFIX_WITH_MULTI_DF = """
This is the result of `print(df.head())` for each dataframe:
{dfs_head}
Begin!
Question: {input}
{agent_scratchpad}"""
PREFIX_FUNCTIONS = """
You are working with a pandas dataframe in Python. The name of the dataframe is `df`."""
MULTI_DF_PREFIX_FUNCTIONS = """
You are working with {num_dfs} pandas dataframes in Python named df1, df2, etc."""
FUNCTIONS_WITH_DF = """
This is the result of `print(df.head())`:
{df_head}"""
FUNCTIONS_WITH_MULTI_DF = """
This is the result of `print(df.head())` for each dataframe:
{dfs_head}"""

View File

@ -0,0 +1,59 @@
"""Python agent."""
from typing import Any, Dict, Optional
from langchain.agents.agent import AgentExecutor, BaseSingleActionAgent
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
from langchain.agents.types import AgentType
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.messages import SystemMessage
from langchain_experimental.agents.agent_toolkits.python.prompt import PREFIX
from langchain_experimental.tools.python.tool import PythonREPLTool
def create_python_agent(
llm: BaseLanguageModel,
tool: PythonREPLTool,
agent_type: AgentType = AgentType.ZERO_SHOT_REACT_DESCRIPTION,
callback_manager: Optional[BaseCallbackManager] = None,
verbose: bool = False,
prefix: str = PREFIX,
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
**kwargs: Dict[str, Any],
) -> AgentExecutor:
"""Construct a python agent from an LLM and tool."""
tools = [tool]
agent: BaseSingleActionAgent
if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION:
prompt = ZeroShotAgent.create_prompt(tools, prefix=prefix)
llm_chain = LLMChain(
llm=llm,
prompt=prompt,
callback_manager=callback_manager,
)
tool_names = [tool.name for tool in tools]
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
elif agent_type == AgentType.OPENAI_FUNCTIONS:
system_message = SystemMessage(content=prefix)
_prompt = OpenAIFunctionsAgent.create_prompt(system_message=system_message)
agent = OpenAIFunctionsAgent(
llm=llm,
prompt=_prompt,
tools=tools,
callback_manager=callback_manager,
**kwargs,
)
else:
raise ValueError(f"Agent type {agent_type} not supported at the moment.")
return AgentExecutor.from_agent_and_tools(
agent=agent,
tools=tools,
callback_manager=callback_manager,
verbose=verbose,
**(agent_executor_kwargs or {}),
)

View File

@ -0,0 +1,9 @@
# flake8: noqa
PREFIX = """You are an agent designed to write and execute python code to answer questions.
You have access to a python REPL, which you can use to execute python code.
If you get an error, debug your code and try again.
Only use the output of your code to answer the question.
You might know the answer without running any code, but you should still run the code to get the answer.
If it does not seem like you can write code to answer the question, just return "I don't know" as the answer.
"""

View File

@ -0,0 +1 @@
"""spark toolkit"""

View File

@ -0,0 +1,81 @@
"""Agent for working with pandas objects."""
from typing import Any, Dict, List, Optional
from langchain.agents.agent import AgentExecutor
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain
from langchain.llms.base import BaseLLM
from langchain_experimental.agents.agent_toolkits.spark.prompt import PREFIX, SUFFIX
from langchain_experimental.tools.python.tool import PythonAstREPLTool
def _validate_spark_df(df: Any) -> bool:
try:
from pyspark.sql import DataFrame as SparkLocalDataFrame
return isinstance(df, SparkLocalDataFrame)
except ImportError:
return False
def _validate_spark_connect_df(df: Any) -> bool:
try:
from pyspark.sql.connect.dataframe import DataFrame as SparkConnectDataFrame
return isinstance(df, SparkConnectDataFrame)
except ImportError:
return False
def create_spark_dataframe_agent(
llm: BaseLLM,
df: Any,
callback_manager: Optional[BaseCallbackManager] = None,
prefix: str = PREFIX,
suffix: str = SUFFIX,
input_variables: Optional[List[str]] = None,
verbose: bool = False,
return_intermediate_steps: bool = False,
max_iterations: Optional[int] = 15,
max_execution_time: Optional[float] = None,
early_stopping_method: str = "force",
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
**kwargs: Dict[str, Any],
) -> AgentExecutor:
"""Construct a Spark agent from an LLM and dataframe."""
if not _validate_spark_df(df) and not _validate_spark_connect_df(df):
raise ImportError("Spark is not installed. run `pip install pyspark`.")
if input_variables is None:
input_variables = ["df", "input", "agent_scratchpad"]
tools = [PythonAstREPLTool(locals={"df": df})]
prompt = ZeroShotAgent.create_prompt(
tools, prefix=prefix, suffix=suffix, input_variables=input_variables
)
partial_prompt = prompt.partial(df=str(df.first()))
llm_chain = LLMChain(
llm=llm,
prompt=partial_prompt,
callback_manager=callback_manager,
)
tool_names = [tool.name for tool in tools]
agent = ZeroShotAgent(
llm_chain=llm_chain,
allowed_tools=tool_names,
callback_manager=callback_manager,
**kwargs,
)
return AgentExecutor.from_agent_and_tools(
agent=agent,
tools=tools,
callback_manager=callback_manager,
verbose=verbose,
return_intermediate_steps=return_intermediate_steps,
max_iterations=max_iterations,
max_execution_time=max_execution_time,
early_stopping_method=early_stopping_method,
**(agent_executor_kwargs or {}),
)

View File

@ -0,0 +1,13 @@
# flake8: noqa
PREFIX = """
You are working with a spark dataframe in Python. The name of the dataframe is `df`.
You should use the tools below to answer the question posed of you:"""
SUFFIX = """
This is the result of `print(df.first())`:
{df}
Begin!
Question: {input}
{agent_scratchpad}"""

View File

@ -0,0 +1 @@
"""Xorbits toolkit."""

View File

@ -0,0 +1,91 @@
"""Agent for working with xorbits objects."""
from typing import Any, Dict, List, Optional
from langchain.agents.agent import AgentExecutor
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain
from langchain.llms.base import BaseLLM
from langchain_experimental.agents.agent_toolkits.xorbits.prompt import (
NP_PREFIX,
NP_SUFFIX,
PD_PREFIX,
PD_SUFFIX,
)
from langchain_experimental.tools.python.tool import PythonAstREPLTool
def create_xorbits_agent(
llm: BaseLLM,
data: Any,
callback_manager: Optional[BaseCallbackManager] = None,
prefix: str = "",
suffix: str = "",
input_variables: Optional[List[str]] = None,
verbose: bool = False,
return_intermediate_steps: bool = False,
max_iterations: Optional[int] = 15,
max_execution_time: Optional[float] = None,
early_stopping_method: str = "force",
agent_executor_kwargs: Optional[Dict[str, Any]] = None,
**kwargs: Dict[str, Any],
) -> AgentExecutor:
"""Construct a xorbits agent from an LLM and dataframe."""
try:
from xorbits import numpy as np
from xorbits import pandas as pd
except ImportError:
raise ImportError(
"Xorbits package not installed, please install with `pip install xorbits`"
)
if not isinstance(data, (pd.DataFrame, np.ndarray)):
raise ValueError(
f"Expected Xorbits DataFrame or ndarray object, got {type(data)}"
)
if input_variables is None:
input_variables = ["data", "input", "agent_scratchpad"]
tools = [PythonAstREPLTool(locals={"data": data})]
prompt, partial_input = None, None
if isinstance(data, pd.DataFrame):
prompt = ZeroShotAgent.create_prompt(
tools,
prefix=PD_PREFIX if prefix == "" else prefix,
suffix=PD_SUFFIX if suffix == "" else suffix,
input_variables=input_variables,
)
partial_input = str(data.head())
else:
prompt = ZeroShotAgent.create_prompt(
tools,
prefix=NP_PREFIX if prefix == "" else prefix,
suffix=NP_SUFFIX if suffix == "" else suffix,
input_variables=input_variables,
)
partial_input = str(data[: len(data) // 2])
partial_prompt = prompt.partial(data=partial_input)
llm_chain = LLMChain(
llm=llm,
prompt=partial_prompt,
callback_manager=callback_manager,
)
tool_names = [tool.name for tool in tools]
agent = ZeroShotAgent(
llm_chain=llm_chain,
allowed_tools=tool_names,
callback_manager=callback_manager,
**kwargs,
)
return AgentExecutor.from_agent_and_tools(
agent=agent,
tools=tools,
callback_manager=callback_manager,
verbose=verbose,
return_intermediate_steps=return_intermediate_steps,
max_iterations=max_iterations,
max_execution_time=max_execution_time,
early_stopping_method=early_stopping_method,
**(agent_executor_kwargs or {}),
)

View File

@ -0,0 +1,33 @@
PD_PREFIX = """
You are working with Xorbits dataframe object in Python.
Before importing Numpy or Pandas in the current script,
remember to import the xorbits version of the library instead.
To import the xorbits version of Numpy, replace the original import statement
`import pandas as pd` with `import xorbits.pandas as pd`.
The name of the input is `data`.
You should use the tools below to answer the question posed of you:"""
PD_SUFFIX = """
This is the result of `print(data)`:
{data}
Begin!
Question: {input}
{agent_scratchpad}"""
NP_PREFIX = """
You are working with Xorbits ndarray object in Python.
Before importing Numpy in the current script,
remember to import the xorbits version of the library instead.
To import the xorbits version of Numpy, replace the original import statement
`import numpy as np` with `import xorbits.numpy as np`.
The name of the input is `data`.
You should use the tools below to answer the question posed of you:"""
NP_SUFFIX = """
This is the result of `print(data)`:
{data}
Begin!
Question: {input}
{agent_scratchpad}"""

View File

@ -0,0 +1,150 @@
"""A tool for running python code in a REPL."""
import ast
import asyncio
import re
import sys
from contextlib import redirect_stdout
from io import StringIO
from typing import Any, Dict, Optional, Type
from langchain.callbacks.manager import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain.pydantic_v1 import BaseModel, Field, root_validator
from langchain.tools.base import BaseTool
from langchain_experimental.utilities.python import PythonREPL
def _get_default_python_repl() -> PythonREPL:
return PythonREPL(_globals=globals(), _locals=None)
def sanitize_input(query: str) -> str:
"""Sanitize input to the python REPL.
Remove whitespace, backtick & python (if llm mistakes python console as terminal)
Args:
query: The query to sanitize
Returns:
str: The sanitized query
"""
# Removes `, whitespace & python from start
query = re.sub(r"^(\s|`)*(?i:python)?\s*", "", query)
# Removes whitespace & ` from end
query = re.sub(r"(\s|`)*$", "", query)
return query
class PythonREPLTool(BaseTool):
"""A tool for running python code in a REPL."""
name: str = "Python_REPL"
description: str = (
"A Python shell. Use this to execute python commands. "
"Input should be a valid python command. "
"If you want to see the output of a value, you should print it out "
"with `print(...)`."
)
python_repl: PythonREPL = Field(default_factory=_get_default_python_repl)
sanitize_input: bool = True
def _run(
self,
query: str,
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> Any:
"""Use the tool."""
if self.sanitize_input:
query = sanitize_input(query)
return self.python_repl.run(query)
async def _arun(
self,
query: str,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> Any:
"""Use the tool asynchronously."""
if self.sanitize_input:
query = sanitize_input(query)
loop = asyncio.get_running_loop()
result = await loop.run_in_executor(None, self.run, query)
return result
class PythonInputs(BaseModel):
query: str = Field(description="code snippet to run")
class PythonAstREPLTool(BaseTool):
"""A tool for running python code in a REPL."""
name: str = "python_repl_ast"
description: str = (
"A Python shell. Use this to execute python commands. "
"Input should be a valid python command. "
"When using this tool, sometimes output is abbreviated - "
"make sure it does not look abbreviated before using it in your answer."
)
globals: Optional[Dict] = Field(default_factory=dict)
locals: Optional[Dict] = Field(default_factory=dict)
sanitize_input: bool = True
args_schema: Type[BaseModel] = PythonInputs
@root_validator(pre=True)
def validate_python_version(cls, values: Dict) -> Dict:
"""Validate valid python version."""
if sys.version_info < (3, 9):
raise ValueError(
"This tool relies on Python 3.9 or higher "
"(as it uses new functionality in the `ast` module, "
f"you have Python version: {sys.version}"
)
return values
def _run(
self,
query: str,
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> str:
"""Use the tool."""
try:
if self.sanitize_input:
query = sanitize_input(query)
tree = ast.parse(query)
module = ast.Module(tree.body[:-1], type_ignores=[])
exec(ast.unparse(module), self.globals, self.locals) # type: ignore
module_end = ast.Module(tree.body[-1:], type_ignores=[])
module_end_str = ast.unparse(module_end) # type: ignore
io_buffer = StringIO()
try:
with redirect_stdout(io_buffer):
ret = eval(module_end_str, self.globals, self.locals)
if ret is None:
return io_buffer.getvalue()
else:
return ret
except Exception:
with redirect_stdout(io_buffer):
exec(module_end_str, self.globals, self.locals)
return io_buffer.getvalue()
except Exception as e:
return "{}: {}".format(type(e).__name__, str(e))
async def _arun(
self,
query: str,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> Any:
"""Use the tool asynchronously."""
loop = asyncio.get_running_loop()
result = await loop.run_in_executor(None, self._run, query)
return result

View File

@ -0,0 +1,71 @@
import functools
import logging
import multiprocessing
import sys
from io import StringIO
from typing import Dict, Optional
from langchain.pydantic_v1 import BaseModel, Field
logger = logging.getLogger(__name__)
@functools.lru_cache(maxsize=None)
def warn_once() -> None:
"""Warn once about the dangers of PythonREPL."""
logger.warning("Python REPL can execute arbitrary code. Use with caution.")
class PythonREPL(BaseModel):
"""Simulates a standalone Python REPL."""
globals: Optional[Dict] = Field(default_factory=dict, alias="_globals")
locals: Optional[Dict] = Field(default_factory=dict, alias="_locals")
@classmethod
def worker(
cls,
command: str,
globals: Optional[Dict],
locals: Optional[Dict],
queue: multiprocessing.Queue,
) -> None:
old_stdout = sys.stdout
sys.stdout = mystdout = StringIO()
try:
exec(command, globals, locals)
sys.stdout = old_stdout
queue.put(mystdout.getvalue())
except Exception as e:
sys.stdout = old_stdout
queue.put(repr(e))
def run(self, command: str, timeout: Optional[int] = None) -> str:
"""Run command with own globals/locals and returns anything printed.
Timeout after the specified number of seconds."""
# Warn against dangers of PythonREPL
warn_once()
queue: multiprocessing.Queue = multiprocessing.Queue()
# Only use multiprocessing if we are enforcing a timeout
if timeout is not None:
# create a Process
p = multiprocessing.Process(
target=self.worker, args=(command, self.globals, self.locals, queue)
)
# start it
p.start()
# wait for the process to finish or kill it after timeout seconds
p.join(timeout)
if p.is_alive():
p.terminate()
return "Execution timed out"
else:
self.worker(command, self.globals, self.locals, queue)
# get the result from the worker function
return queue.get()

View File

@ -0,0 +1,112 @@
"""Test functionality of Python REPL."""
import sys
import pytest
from langchain_experimental.tools.python.tool import PythonAstREPLTool, PythonREPLTool
from langchain_experimental.utilities.python import PythonREPL
_SAMPLE_CODE = """
```
def multiply():
print(5*6)
multiply()
```
"""
_AST_SAMPLE_CODE = """
```
def multiply():
return(5*6)
multiply()
```
"""
_AST_SAMPLE_CODE_EXECUTE = """
```
def multiply(a, b):
return(5*6)
a = 5
b = 6
multiply(a, b)
```
"""
def test_python_repl() -> None:
"""Test functionality when globals/locals are not provided."""
repl = PythonREPL()
# Run a simple initial command.
repl.run("foo = 1")
assert repl.locals is not None
assert repl.locals["foo"] == 1
# Now run a command that accesses `foo` to make sure it still has it.
repl.run("bar = foo * 2")
assert repl.locals is not None
assert repl.locals["bar"] == 2
def test_python_repl_no_previous_variables() -> None:
"""Test that it does not have access to variables created outside the scope."""
foo = 3 # noqa: F841
repl = PythonREPL()
output = repl.run("print(foo)")
assert output == """NameError("name 'foo' is not defined")"""
def test_python_repl_pass_in_locals() -> None:
"""Test functionality when passing in locals."""
_locals = {"foo": 4}
repl = PythonREPL(_locals=_locals)
repl.run("bar = foo * 2")
assert repl.locals is not None
assert repl.locals["bar"] == 8
def test_functionality() -> None:
"""Test correct functionality."""
chain = PythonREPL()
code = "print(1 + 1)"
output = chain.run(code)
assert output == "2\n"
def test_functionality_multiline() -> None:
"""Test correct functionality for ChatGPT multiline commands."""
chain = PythonREPL()
tool = PythonREPLTool(python_repl=chain)
output = tool.run(_SAMPLE_CODE)
assert output == "30\n"
def test_python_ast_repl_multiline() -> None:
"""Test correct functionality for ChatGPT multiline commands."""
if sys.version_info < (3, 9):
pytest.skip("Python 3.9+ is required for this test")
tool = PythonAstREPLTool()
output = tool.run(_AST_SAMPLE_CODE)
assert output == 30
def test_python_ast_repl_multi_statement() -> None:
"""Test correct functionality for ChatGPT multi statement commands."""
if sys.version_info < (3, 9):
pytest.skip("Python 3.9+ is required for this test")
tool = PythonAstREPLTool()
output = tool.run(_AST_SAMPLE_CODE_EXECUTE)
assert output == 30
def test_function() -> None:
"""Test correct functionality."""
chain = PythonREPL()
code = "def add(a, b): " " return a + b"
output = chain.run(code)
assert output == ""
code = "print(add(1, 2))"
output = chain.run(code)
assert output == "3\n"

View File

@ -0,0 +1,164 @@
"""Test Python REPL Tools."""
import sys
import numpy as np
import pytest
from langchain_experimental.tools.python.tool import (
PythonAstREPLTool,
PythonREPLTool,
sanitize_input,
)
def test_python_repl_tool_single_input() -> None:
"""Test that the python REPL tool works with a single input."""
tool = PythonREPLTool()
assert tool.is_single_input
assert int(tool.run("print(1 + 1)").strip()) == 2
def test_python_repl_print() -> None:
program = """
import numpy as np
v1 = np.array([1, 2, 3])
v2 = np.array([4, 5, 6])
dot_product = np.dot(v1, v2)
print("The dot product is {:d}.".format(dot_product))
"""
tool = PythonREPLTool()
assert tool.run(program) == "The dot product is 32.\n"
@pytest.mark.skipif(
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
)
def test_python_ast_repl_tool_single_input() -> None:
"""Test that the python REPL tool works with a single input."""
tool = PythonAstREPLTool()
assert tool.is_single_input
assert tool.run("1 + 1") == 2
@pytest.mark.skipif(
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
)
def test_python_ast_repl_return() -> None:
program = """
```
import numpy as np
v1 = np.array([1, 2, 3])
v2 = np.array([4, 5, 6])
dot_product = np.dot(v1, v2)
int(dot_product)
```
"""
tool = PythonAstREPLTool()
assert tool.run(program) == 32
program = """
```python
import numpy as np
v1 = np.array([1, 2, 3])
v2 = np.array([4, 5, 6])
dot_product = np.dot(v1, v2)
int(dot_product)
```
"""
assert tool.run(program) == 32
@pytest.mark.skipif(
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
)
def test_python_ast_repl_print() -> None:
program = """python
string = "racecar"
if string == string[::-1]:
print(string, "is a palindrome")
else:
print(string, "is not a palindrome")"""
tool = PythonAstREPLTool()
assert tool.run(program) == "racecar is a palindrome\n"
@pytest.mark.skipif(
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
)
def test_repl_print_python_backticks() -> None:
program = "`print('`python` is a great language.')`"
tool = PythonAstREPLTool()
assert tool.run(program) == "`python` is a great language.\n"
@pytest.mark.skipif(
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
)
def test_python_ast_repl_raise_exception() -> None:
data = {"Name": ["John", "Alice"], "Age": [30, 25]}
program = """
import pandas as pd
df = pd.DataFrame(data)
df['Gender']
"""
tool = PythonAstREPLTool(locals={"data": data})
expected_outputs = (
"KeyError: 'Gender'",
"ModuleNotFoundError: No module named 'pandas'",
)
assert tool.run(program) in expected_outputs
@pytest.mark.skipif(
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
)
def test_python_ast_repl_one_line_print() -> None:
program = 'print("The square of {} is {:.2f}".format(3, 3**2))'
tool = PythonAstREPLTool()
assert tool.run(program) == "The square of 3 is 9.00\n"
@pytest.mark.skipif(
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
)
def test_python_ast_repl_one_line_return() -> None:
arr = np.array([1, 2, 3, 4, 5])
tool = PythonAstREPLTool(locals={"arr": arr})
program = "`(arr**2).sum() # Returns sum of squares`"
assert tool.run(program) == 55
@pytest.mark.skipif(
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
)
def test_python_ast_repl_one_line_exception() -> None:
program = "[1, 2, 3][4]"
tool = PythonAstREPLTool()
assert tool.run(program) == "IndexError: list index out of range"
def test_sanitize_input() -> None:
query = """
```
p = 5
```
"""
expected = "p = 5"
actual = sanitize_input(query)
assert expected == actual
query = """
```python
p = 5
```
"""
expected = "p = 5"
actual = sanitize_input(query)
assert expected == actual
query = """
p = 5
"""
expected = "p = 5"
actual = sanitize_input(query)
assert expected == actual