Add python,pandas,xorbits,spark agents to experimental (#11774)

For context, see https://github.com/langchain-ai/langchain/discussions/11680

parent d6e34ca2ee
commit 0d37b4c27d
@@ -0,0 +1 @@
"""Pandas toolkit."""
@@ -0,0 +1,341 @@
"""Agent for working with pandas objects."""
from typing import Any, Dict, List, Optional, Sequence, Tuple

from langchain.agents.agent import AgentExecutor, BaseSingleActionAgent
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
from langchain.agents.types import AgentType
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain
from langchain.schema import BasePromptTemplate
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.messages import SystemMessage
from langchain.tools import BaseTool

from langchain_experimental.agents.agent_toolkits.pandas.prompt import (
    FUNCTIONS_WITH_DF,
    FUNCTIONS_WITH_MULTI_DF,
    MULTI_DF_PREFIX,
    MULTI_DF_PREFIX_FUNCTIONS,
    PREFIX,
    PREFIX_FUNCTIONS,
    SUFFIX_NO_DF,
    SUFFIX_WITH_DF,
    SUFFIX_WITH_MULTI_DF,
)
from langchain_experimental.tools.python.tool import PythonAstREPLTool


def _get_multi_prompt(
    dfs: List[Any],
    prefix: Optional[str] = None,
    suffix: Optional[str] = None,
    input_variables: Optional[List[str]] = None,
    include_df_in_prompt: Optional[bool] = True,
    number_of_head_rows: int = 5,
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
    num_dfs = len(dfs)
    if suffix is not None:
        suffix_to_use = suffix
        include_dfs_head = True
    elif include_df_in_prompt:
        suffix_to_use = SUFFIX_WITH_MULTI_DF
        include_dfs_head = True
    else:
        suffix_to_use = SUFFIX_NO_DF
        include_dfs_head = False
    if input_variables is None:
        input_variables = ["input", "agent_scratchpad", "num_dfs"]
        if include_dfs_head:
            input_variables += ["dfs_head"]

    if prefix is None:
        prefix = MULTI_DF_PREFIX

    df_locals = {}
    for i, dataframe in enumerate(dfs):
        df_locals[f"df{i + 1}"] = dataframe
    tools = [PythonAstREPLTool(locals=df_locals)]

    prompt = ZeroShotAgent.create_prompt(
        tools, prefix=prefix, suffix=suffix_to_use, input_variables=input_variables
    )

    partial_prompt = prompt.partial()
    if "dfs_head" in input_variables:
        dfs_head = "\n\n".join([d.head(number_of_head_rows).to_markdown() for d in dfs])
        partial_prompt = partial_prompt.partial(num_dfs=str(num_dfs), dfs_head=dfs_head)
    if "num_dfs" in input_variables:
        partial_prompt = partial_prompt.partial(num_dfs=str(num_dfs))
    return partial_prompt, tools


def _get_single_prompt(
    df: Any,
    prefix: Optional[str] = None,
    suffix: Optional[str] = None,
    input_variables: Optional[List[str]] = None,
    include_df_in_prompt: Optional[bool] = True,
    number_of_head_rows: int = 5,
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
    if suffix is not None:
        suffix_to_use = suffix
        include_df_head = True
    elif include_df_in_prompt:
        suffix_to_use = SUFFIX_WITH_DF
        include_df_head = True
    else:
        suffix_to_use = SUFFIX_NO_DF
        include_df_head = False

    if input_variables is None:
        input_variables = ["input", "agent_scratchpad"]
        if include_df_head:
            input_variables += ["df_head"]

    if prefix is None:
        prefix = PREFIX

    tools = [PythonAstREPLTool(locals={"df": df})]

    prompt = ZeroShotAgent.create_prompt(
        tools, prefix=prefix, suffix=suffix_to_use, input_variables=input_variables
    )

    partial_prompt = prompt.partial()
    if "df_head" in input_variables:
        partial_prompt = partial_prompt.partial(
            df_head=str(df.head(number_of_head_rows).to_markdown())
        )
    return partial_prompt, tools


def _get_prompt_and_tools(
    df: Any,
    prefix: Optional[str] = None,
    suffix: Optional[str] = None,
    input_variables: Optional[List[str]] = None,
    include_df_in_prompt: Optional[bool] = True,
    number_of_head_rows: int = 5,
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
    try:
        import pandas as pd

        pd.set_option("display.max_columns", None)
    except ImportError:
        raise ImportError(
            "pandas package not found, please install with `pip install pandas`"
        )

    if include_df_in_prompt is not None and suffix is not None:
        raise ValueError("If suffix is specified, include_df_in_prompt should not be.")

    if isinstance(df, list):
        for item in df:
            if not isinstance(item, pd.DataFrame):
                raise ValueError(f"Expected pandas object, got {type(df)}")
        return _get_multi_prompt(
            df,
            prefix=prefix,
            suffix=suffix,
            input_variables=input_variables,
            include_df_in_prompt=include_df_in_prompt,
            number_of_head_rows=number_of_head_rows,
        )
    else:
        if not isinstance(df, pd.DataFrame):
            raise ValueError(f"Expected pandas object, got {type(df)}")
        return _get_single_prompt(
            df,
            prefix=prefix,
            suffix=suffix,
            input_variables=input_variables,
            include_df_in_prompt=include_df_in_prompt,
            number_of_head_rows=number_of_head_rows,
        )


def _get_functions_single_prompt(
    df: Any,
    prefix: Optional[str] = None,
    suffix: Optional[str] = None,
    include_df_in_prompt: Optional[bool] = True,
    number_of_head_rows: int = 5,
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
    if suffix is not None:
        suffix_to_use = suffix
        if include_df_in_prompt:
            suffix_to_use = suffix_to_use.format(
                df_head=str(df.head(number_of_head_rows).to_markdown())
            )
    elif include_df_in_prompt:
        suffix_to_use = FUNCTIONS_WITH_DF.format(
            df_head=str(df.head(number_of_head_rows).to_markdown())
        )
    else:
        suffix_to_use = ""

    if prefix is None:
        prefix = PREFIX_FUNCTIONS

    tools = [PythonAstREPLTool(locals={"df": df})]
    system_message = SystemMessage(content=prefix + suffix_to_use)
    prompt = OpenAIFunctionsAgent.create_prompt(system_message=system_message)
    return prompt, tools


def _get_functions_multi_prompt(
    dfs: Any,
    prefix: Optional[str] = None,
    suffix: Optional[str] = None,
    include_df_in_prompt: Optional[bool] = True,
    number_of_head_rows: int = 5,
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
    if suffix is not None:
        suffix_to_use = suffix
        if include_df_in_prompt:
            dfs_head = "\n\n".join(
                [d.head(number_of_head_rows).to_markdown() for d in dfs]
            )
            suffix_to_use = suffix_to_use.format(
                dfs_head=dfs_head,
            )
    elif include_df_in_prompt:
        dfs_head = "\n\n".join([d.head(number_of_head_rows).to_markdown() for d in dfs])
        suffix_to_use = FUNCTIONS_WITH_MULTI_DF.format(
            dfs_head=dfs_head,
        )
    else:
        suffix_to_use = ""

    if prefix is None:
        prefix = MULTI_DF_PREFIX_FUNCTIONS
    prefix = prefix.format(num_dfs=str(len(dfs)))

    df_locals = {}
    for i, dataframe in enumerate(dfs):
        df_locals[f"df{i + 1}"] = dataframe
    tools = [PythonAstREPLTool(locals=df_locals)]
    system_message = SystemMessage(content=prefix + suffix_to_use)
    prompt = OpenAIFunctionsAgent.create_prompt(system_message=system_message)
    return prompt, tools


def _get_functions_prompt_and_tools(
    df: Any,
    prefix: Optional[str] = None,
    suffix: Optional[str] = None,
    input_variables: Optional[List[str]] = None,
    include_df_in_prompt: Optional[bool] = True,
    number_of_head_rows: int = 5,
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
    try:
        import pandas as pd

        pd.set_option("display.max_columns", None)
    except ImportError:
        raise ImportError(
            "pandas package not found, please install with `pip install pandas`"
        )
    if input_variables is not None:
        raise ValueError("`input_variables` is not supported at the moment.")

    if include_df_in_prompt is not None and suffix is not None:
        raise ValueError("If suffix is specified, include_df_in_prompt should not be.")

    if isinstance(df, list):
        for item in df:
            if not isinstance(item, pd.DataFrame):
                raise ValueError(f"Expected pandas object, got {type(df)}")
        return _get_functions_multi_prompt(
            df,
            prefix=prefix,
            suffix=suffix,
            include_df_in_prompt=include_df_in_prompt,
            number_of_head_rows=number_of_head_rows,
        )
    else:
        if not isinstance(df, pd.DataFrame):
            raise ValueError(f"Expected pandas object, got {type(df)}")
        return _get_functions_single_prompt(
            df,
            prefix=prefix,
            suffix=suffix,
            include_df_in_prompt=include_df_in_prompt,
            number_of_head_rows=number_of_head_rows,
        )


def create_pandas_dataframe_agent(
    llm: BaseLanguageModel,
    df: Any,
    agent_type: AgentType = AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    callback_manager: Optional[BaseCallbackManager] = None,
    prefix: Optional[str] = None,
    suffix: Optional[str] = None,
    input_variables: Optional[List[str]] = None,
    verbose: bool = False,
    return_intermediate_steps: bool = False,
    max_iterations: Optional[int] = 15,
    max_execution_time: Optional[float] = None,
    early_stopping_method: str = "force",
    agent_executor_kwargs: Optional[Dict[str, Any]] = None,
    include_df_in_prompt: Optional[bool] = True,
    number_of_head_rows: int = 5,
    extra_tools: Sequence[BaseTool] = (),
    **kwargs: Dict[str, Any],
) -> AgentExecutor:
    """Construct a pandas agent from an LLM and dataframe."""
    agent: BaseSingleActionAgent
    if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION:
        prompt, base_tools = _get_prompt_and_tools(
            df,
            prefix=prefix,
            suffix=suffix,
            input_variables=input_variables,
            include_df_in_prompt=include_df_in_prompt,
            number_of_head_rows=number_of_head_rows,
        )
        tools = base_tools + list(extra_tools)
        llm_chain = LLMChain(
            llm=llm,
            prompt=prompt,
            callback_manager=callback_manager,
        )
        tool_names = [tool.name for tool in tools]
        agent = ZeroShotAgent(
            llm_chain=llm_chain,
            allowed_tools=tool_names,
            callback_manager=callback_manager,
            **kwargs,
        )
    elif agent_type == AgentType.OPENAI_FUNCTIONS:
        _prompt, base_tools = _get_functions_prompt_and_tools(
            df,
            prefix=prefix,
            suffix=suffix,
            input_variables=input_variables,
            include_df_in_prompt=include_df_in_prompt,
            number_of_head_rows=number_of_head_rows,
        )
        tools = base_tools + list(extra_tools)
        agent = OpenAIFunctionsAgent(
            llm=llm,
            prompt=_prompt,
            tools=tools,
            callback_manager=callback_manager,
            **kwargs,
        )
    else:
        raise ValueError(f"Agent type {agent_type} not supported at the moment.")
    return AgentExecutor.from_agent_and_tools(
        agent=agent,
        tools=tools,
        callback_manager=callback_manager,
        verbose=verbose,
        return_intermediate_steps=return_intermediate_steps,
        max_iterations=max_iterations,
        max_execution_time=max_execution_time,
        early_stopping_method=early_stopping_method,
        **(agent_executor_kwargs or {}),
    )
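For readers of this diff, a minimal usage sketch of the new pandas agent (not part of the commit); it assumes an OpenAI API key is configured, and the CSV path, question, and agent type are illustrative:

import pandas as pd
from langchain.agents import AgentType
from langchain.chat_models import ChatOpenAI

from langchain_experimental.agents.agent_toolkits.pandas.base import (
    create_pandas_dataframe_agent,
)

# Illustrative dataset; any DataFrame (or a list of DataFrames) works.
df = pd.read_csv("titanic.csv")

agent = create_pandas_dataframe_agent(
    ChatOpenAI(temperature=0),
    df,
    agent_type=AgentType.OPENAI_FUNCTIONS,  # or the default ZERO_SHOT_REACT_DESCRIPTION
    verbose=True,
)
agent.run("How many rows are there?")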
@@ -0,0 +1,44 @@
# flake8: noqa

PREFIX = """
You are working with a pandas dataframe in Python. The name of the dataframe is `df`.
You should use the tools below to answer the question posed of you:"""

MULTI_DF_PREFIX = """
You are working with {num_dfs} pandas dataframes in Python named df1, df2, etc. You
should use the tools below to answer the question posed of you:"""

SUFFIX_NO_DF = """
Begin!
Question: {input}
{agent_scratchpad}"""

SUFFIX_WITH_DF = """
This is the result of `print(df.head())`:
{df_head}

Begin!
Question: {input}
{agent_scratchpad}"""

SUFFIX_WITH_MULTI_DF = """
This is the result of `print(df.head())` for each dataframe:
{dfs_head}

Begin!
Question: {input}
{agent_scratchpad}"""

PREFIX_FUNCTIONS = """
You are working with a pandas dataframe in Python. The name of the dataframe is `df`."""

MULTI_DF_PREFIX_FUNCTIONS = """
You are working with {num_dfs} pandas dataframes in Python named df1, df2, etc."""

FUNCTIONS_WITH_DF = """
This is the result of `print(df.head())`:
{df_head}"""

FUNCTIONS_WITH_MULTI_DF = """
This is the result of `print(df.head())` for each dataframe:
{dfs_head}"""
@@ -0,0 +1,59 @@
"""Python agent."""

from typing import Any, Dict, Optional

from langchain.agents.agent import AgentExecutor, BaseSingleActionAgent
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
from langchain.agents.types import AgentType
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.messages import SystemMessage

from langchain_experimental.agents.agent_toolkits.python.prompt import PREFIX
from langchain_experimental.tools.python.tool import PythonREPLTool


def create_python_agent(
    llm: BaseLanguageModel,
    tool: PythonREPLTool,
    agent_type: AgentType = AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    callback_manager: Optional[BaseCallbackManager] = None,
    verbose: bool = False,
    prefix: str = PREFIX,
    agent_executor_kwargs: Optional[Dict[str, Any]] = None,
    **kwargs: Dict[str, Any],
) -> AgentExecutor:
    """Construct a python agent from an LLM and tool."""
    tools = [tool]
    agent: BaseSingleActionAgent

    if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION:
        prompt = ZeroShotAgent.create_prompt(tools, prefix=prefix)
        llm_chain = LLMChain(
            llm=llm,
            prompt=prompt,
            callback_manager=callback_manager,
        )
        tool_names = [tool.name for tool in tools]
        agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
    elif agent_type == AgentType.OPENAI_FUNCTIONS:
        system_message = SystemMessage(content=prefix)
        _prompt = OpenAIFunctionsAgent.create_prompt(system_message=system_message)
        agent = OpenAIFunctionsAgent(
            llm=llm,
            prompt=_prompt,
            tools=tools,
            callback_manager=callback_manager,
            **kwargs,
        )
    else:
        raise ValueError(f"Agent type {agent_type} not supported at the moment.")
    return AgentExecutor.from_agent_and_tools(
        agent=agent,
        tools=tools,
        callback_manager=callback_manager,
        verbose=verbose,
        **(agent_executor_kwargs or {}),
    )
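A similar sketch for the Python agent (not part of the commit); the module path and LLM choice are assumptions based on the layout in this diff, and the REPL executes whatever code the model writes, so treat it as untrusted-code execution:

from langchain.agents import AgentType
from langchain.llms import OpenAI

from langchain_experimental.agents.agent_toolkits.python.base import create_python_agent
from langchain_experimental.tools.python.tool import PythonREPLTool

agent = create_python_agent(
    llm=OpenAI(temperature=0),
    tool=PythonREPLTool(),
    agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    verbose=True,
)
agent.run("What is the 10th Fibonacci number?")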
@@ -0,0 +1,9 @@
# flake8: noqa

PREFIX = """You are an agent designed to write and execute python code to answer questions.
You have access to a python REPL, which you can use to execute python code.
If you get an error, debug your code and try again.
Only use the output of your code to answer the question.
You might know the answer without running any code, but you should still run the code to get the answer.
If it does not seem like you can write code to answer the question, just return "I don't know" as the answer.
"""
@@ -0,0 +1 @@
"""Spark toolkit."""
@@ -0,0 +1,81 @@
"""Agent for working with Spark DataFrame objects."""
from typing import Any, Dict, List, Optional

from langchain.agents.agent import AgentExecutor
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain
from langchain.llms.base import BaseLLM

from langchain_experimental.agents.agent_toolkits.spark.prompt import PREFIX, SUFFIX
from langchain_experimental.tools.python.tool import PythonAstREPLTool


def _validate_spark_df(df: Any) -> bool:
    try:
        from pyspark.sql import DataFrame as SparkLocalDataFrame

        return isinstance(df, SparkLocalDataFrame)
    except ImportError:
        return False


def _validate_spark_connect_df(df: Any) -> bool:
    try:
        from pyspark.sql.connect.dataframe import DataFrame as SparkConnectDataFrame

        return isinstance(df, SparkConnectDataFrame)
    except ImportError:
        return False


def create_spark_dataframe_agent(
    llm: BaseLLM,
    df: Any,
    callback_manager: Optional[BaseCallbackManager] = None,
    prefix: str = PREFIX,
    suffix: str = SUFFIX,
    input_variables: Optional[List[str]] = None,
    verbose: bool = False,
    return_intermediate_steps: bool = False,
    max_iterations: Optional[int] = 15,
    max_execution_time: Optional[float] = None,
    early_stopping_method: str = "force",
    agent_executor_kwargs: Optional[Dict[str, Any]] = None,
    **kwargs: Dict[str, Any],
) -> AgentExecutor:
    """Construct a Spark agent from an LLM and dataframe."""

    if not _validate_spark_df(df) and not _validate_spark_connect_df(df):
        raise ImportError("Spark is not installed. Run `pip install pyspark`.")

    if input_variables is None:
        input_variables = ["df", "input", "agent_scratchpad"]
    tools = [PythonAstREPLTool(locals={"df": df})]
    prompt = ZeroShotAgent.create_prompt(
        tools, prefix=prefix, suffix=suffix, input_variables=input_variables
    )
    partial_prompt = prompt.partial(df=str(df.first()))
    llm_chain = LLMChain(
        llm=llm,
        prompt=partial_prompt,
        callback_manager=callback_manager,
    )
    tool_names = [tool.name for tool in tools]
    agent = ZeroShotAgent(
        llm_chain=llm_chain,
        allowed_tools=tool_names,
        callback_manager=callback_manager,
        **kwargs,
    )
    return AgentExecutor.from_agent_and_tools(
        agent=agent,
        tools=tools,
        callback_manager=callback_manager,
        verbose=verbose,
        return_intermediate_steps=return_intermediate_steps,
        max_iterations=max_iterations,
        max_execution_time=max_execution_time,
        early_stopping_method=early_stopping_method,
        **(agent_executor_kwargs or {}),
    )
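A hedged sketch for the Spark agent (not part of the commit), assuming a local SparkSession and an illustrative CSV path; note the prompt is pre-filled with `df.first()`, so the DataFrame must be readable when the agent is constructed:

from pyspark.sql import SparkSession
from langchain.llms import OpenAI

from langchain_experimental.agents.agent_toolkits.spark.base import (
    create_spark_dataframe_agent,
)

spark = SparkSession.builder.getOrCreate()
df = spark.read.csv("titanic.csv", header=True, inferSchema=True)  # illustrative path

agent = create_spark_dataframe_agent(llm=OpenAI(temperature=0), df=df, verbose=True)
agent.run("How many rows are in the dataframe?")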
@@ -0,0 +1,13 @@
# flake8: noqa

PREFIX = """
You are working with a spark dataframe in Python. The name of the dataframe is `df`.
You should use the tools below to answer the question posed of you:"""

SUFFIX = """
This is the result of `print(df.first())`:
{df}

Begin!
Question: {input}
{agent_scratchpad}"""
@@ -0,0 +1 @@
"""Xorbits toolkit."""
@@ -0,0 +1,91 @@
"""Agent for working with xorbits objects."""
from typing import Any, Dict, List, Optional

from langchain.agents.agent import AgentExecutor
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain
from langchain.llms.base import BaseLLM

from langchain_experimental.agents.agent_toolkits.xorbits.prompt import (
    NP_PREFIX,
    NP_SUFFIX,
    PD_PREFIX,
    PD_SUFFIX,
)
from langchain_experimental.tools.python.tool import PythonAstREPLTool


def create_xorbits_agent(
    llm: BaseLLM,
    data: Any,
    callback_manager: Optional[BaseCallbackManager] = None,
    prefix: str = "",
    suffix: str = "",
    input_variables: Optional[List[str]] = None,
    verbose: bool = False,
    return_intermediate_steps: bool = False,
    max_iterations: Optional[int] = 15,
    max_execution_time: Optional[float] = None,
    early_stopping_method: str = "force",
    agent_executor_kwargs: Optional[Dict[str, Any]] = None,
    **kwargs: Dict[str, Any],
) -> AgentExecutor:
    """Construct a xorbits agent from an LLM and dataframe."""
    try:
        from xorbits import numpy as np
        from xorbits import pandas as pd
    except ImportError:
        raise ImportError(
            "Xorbits package not installed, please install with `pip install xorbits`"
        )

    if not isinstance(data, (pd.DataFrame, np.ndarray)):
        raise ValueError(
            f"Expected Xorbits DataFrame or ndarray object, got {type(data)}"
        )
    if input_variables is None:
        input_variables = ["data", "input", "agent_scratchpad"]
    tools = [PythonAstREPLTool(locals={"data": data})]
    prompt, partial_input = None, None

    if isinstance(data, pd.DataFrame):
        prompt = ZeroShotAgent.create_prompt(
            tools,
            prefix=PD_PREFIX if prefix == "" else prefix,
            suffix=PD_SUFFIX if suffix == "" else suffix,
            input_variables=input_variables,
        )
        partial_input = str(data.head())
    else:
        prompt = ZeroShotAgent.create_prompt(
            tools,
            prefix=NP_PREFIX if prefix == "" else prefix,
            suffix=NP_SUFFIX if suffix == "" else suffix,
            input_variables=input_variables,
        )
        partial_input = str(data[: len(data) // 2])
    partial_prompt = prompt.partial(data=partial_input)
    llm_chain = LLMChain(
        llm=llm,
        prompt=partial_prompt,
        callback_manager=callback_manager,
    )
    tool_names = [tool.name for tool in tools]
    agent = ZeroShotAgent(
        llm_chain=llm_chain,
        allowed_tools=tool_names,
        callback_manager=callback_manager,
        **kwargs,
    )
    return AgentExecutor.from_agent_and_tools(
        agent=agent,
        tools=tools,
        callback_manager=callback_manager,
        verbose=verbose,
        return_intermediate_steps=return_intermediate_steps,
        max_iterations=max_iterations,
        max_execution_time=max_execution_time,
        early_stopping_method=early_stopping_method,
        **(agent_executor_kwargs or {}),
    )
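And one for the Xorbits agent (not part of the commit); the dataset path is illustrative, and an xorbits ndarray would work the same way via the NumPy-flavored prompt:

import xorbits.pandas as pd
from langchain.llms import OpenAI

from langchain_experimental.agents.agent_toolkits.xorbits.base import create_xorbits_agent

data = pd.read_csv("titanic.csv")  # illustrative path
agent = create_xorbits_agent(llm=OpenAI(temperature=0), data=data, verbose=True)
agent.run("How many rows and columns does the data have?")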
@@ -0,0 +1,33 @@
PD_PREFIX = """
You are working with a Xorbits dataframe object in Python.
Before importing Numpy or Pandas in the current script,
remember to import the xorbits version of the library instead.
To import the xorbits version of Pandas, replace the original import statement
`import pandas as pd` with `import xorbits.pandas as pd`.
The name of the input is `data`.
You should use the tools below to answer the question posed of you:"""

PD_SUFFIX = """
This is the result of `print(data)`:
{data}

Begin!
Question: {input}
{agent_scratchpad}"""

NP_PREFIX = """
You are working with a Xorbits ndarray object in Python.
Before importing Numpy in the current script,
remember to import the xorbits version of the library instead.
To import the xorbits version of Numpy, replace the original import statement
`import numpy as np` with `import xorbits.numpy as np`.
The name of the input is `data`.
You should use the tools below to answer the question posed of you:"""

NP_SUFFIX = """
This is the result of `print(data)`:
{data}

Begin!
Question: {input}
{agent_scratchpad}"""
150  libs/experimental/langchain_experimental/tools/python/tool.py  Normal file
@@ -0,0 +1,150 @@
"""A tool for running python code in a REPL."""

import ast
import asyncio
import re
import sys
from contextlib import redirect_stdout
from io import StringIO
from typing import Any, Dict, Optional, Type

from langchain.callbacks.manager import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)
from langchain.pydantic_v1 import BaseModel, Field, root_validator
from langchain.tools.base import BaseTool

from langchain_experimental.utilities.python import PythonREPL


def _get_default_python_repl() -> PythonREPL:
    return PythonREPL(_globals=globals(), _locals=None)


def sanitize_input(query: str) -> str:
    """Sanitize input to the python REPL.

    Remove whitespace, backtick & python (if llm mistakes python console as terminal)

    Args:
        query: The query to sanitize

    Returns:
        str: The sanitized query
    """

    # Removes `, whitespace & python from start
    query = re.sub(r"^(\s|`)*(?i:python)?\s*", "", query)
    # Removes whitespace & ` from end
    query = re.sub(r"(\s|`)*$", "", query)
    return query


class PythonREPLTool(BaseTool):
    """A tool for running python code in a REPL."""

    name: str = "Python_REPL"
    description: str = (
        "A Python shell. Use this to execute python commands. "
        "Input should be a valid python command. "
        "If you want to see the output of a value, you should print it out "
        "with `print(...)`."
    )
    python_repl: PythonREPL = Field(default_factory=_get_default_python_repl)
    sanitize_input: bool = True

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Any:
        """Use the tool."""
        if self.sanitize_input:
            query = sanitize_input(query)
        return self.python_repl.run(query)

    async def _arun(
        self,
        query: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Any:
        """Use the tool asynchronously."""
        if self.sanitize_input:
            query = sanitize_input(query)

        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(None, self.run, query)

        return result


class PythonInputs(BaseModel):
    query: str = Field(description="code snippet to run")


class PythonAstREPLTool(BaseTool):
    """A tool for running python code in a REPL."""

    name: str = "python_repl_ast"
    description: str = (
        "A Python shell. Use this to execute python commands. "
        "Input should be a valid python command. "
        "When using this tool, sometimes output is abbreviated - "
        "make sure it does not look abbreviated before using it in your answer."
    )
    globals: Optional[Dict] = Field(default_factory=dict)
    locals: Optional[Dict] = Field(default_factory=dict)
    sanitize_input: bool = True
    args_schema: Type[BaseModel] = PythonInputs

    @root_validator(pre=True)
    def validate_python_version(cls, values: Dict) -> Dict:
        """Validate valid python version."""
        if sys.version_info < (3, 9):
            raise ValueError(
                "This tool relies on Python 3.9 or higher "
                "(as it uses new functionality in the `ast` module, "
                f"you have Python version: {sys.version}"
            )
        return values

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Use the tool."""
        try:
            if self.sanitize_input:
                query = sanitize_input(query)
            tree = ast.parse(query)
            module = ast.Module(tree.body[:-1], type_ignores=[])
            exec(ast.unparse(module), self.globals, self.locals)  # type: ignore
            module_end = ast.Module(tree.body[-1:], type_ignores=[])
            module_end_str = ast.unparse(module_end)  # type: ignore
            io_buffer = StringIO()
            try:
                with redirect_stdout(io_buffer):
                    ret = eval(module_end_str, self.globals, self.locals)
                    if ret is None:
                        return io_buffer.getvalue()
                    else:
                        return ret
            except Exception:
                with redirect_stdout(io_buffer):
                    exec(module_end_str, self.globals, self.locals)
                return io_buffer.getvalue()
        except Exception as e:
            return "{}: {}".format(type(e).__name__, str(e))

    async def _arun(
        self,
        query: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Any:
        """Use the tool asynchronously."""

        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(None, self._run, query)

        return result
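A small sketch (not part of the commit) of calling PythonAstREPLTool directly; it illustrates the exec-then-eval behavior above, and the variable names are illustrative:

from langchain_experimental.tools.python.tool import PythonAstREPLTool

tool = PythonAstREPLTool(locals={"n": 10})
# All statements but the last are exec'd; the value of the final expression is returned.
print(tool.run("total = sum(range(n))\ntotal"))  # -> 45
# Exceptions are returned as strings rather than raised.
print(tool.run("undefined_name"))  # -> "NameError: name 'undefined_name' is not defined"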
71  libs/experimental/langchain_experimental/utilities/python.py  Normal file
@@ -0,0 +1,71 @@
import functools
import logging
import multiprocessing
import sys
from io import StringIO
from typing import Dict, Optional

from langchain.pydantic_v1 import BaseModel, Field

logger = logging.getLogger(__name__)


@functools.lru_cache(maxsize=None)
def warn_once() -> None:
    """Warn once about the dangers of PythonREPL."""
    logger.warning("Python REPL can execute arbitrary code. Use with caution.")


class PythonREPL(BaseModel):
    """Simulates a standalone Python REPL."""

    globals: Optional[Dict] = Field(default_factory=dict, alias="_globals")
    locals: Optional[Dict] = Field(default_factory=dict, alias="_locals")

    @classmethod
    def worker(
        cls,
        command: str,
        globals: Optional[Dict],
        locals: Optional[Dict],
        queue: multiprocessing.Queue,
    ) -> None:
        old_stdout = sys.stdout
        sys.stdout = mystdout = StringIO()
        try:
            exec(command, globals, locals)
            sys.stdout = old_stdout
            queue.put(mystdout.getvalue())
        except Exception as e:
            sys.stdout = old_stdout
            queue.put(repr(e))

    def run(self, command: str, timeout: Optional[int] = None) -> str:
        """Run command with own globals/locals and returns anything printed.
        Timeout after the specified number of seconds."""

        # Warn against dangers of PythonREPL
        warn_once()

        queue: multiprocessing.Queue = multiprocessing.Queue()

        # Only use multiprocessing if we are enforcing a timeout
        if timeout is not None:
            # create a Process
            p = multiprocessing.Process(
                target=self.worker, args=(command, self.globals, self.locals, queue)
            )

            # start it
            p.start()

            # wait for the process to finish or kill it after timeout seconds
            p.join(timeout)

            if p.is_alive():
                p.terminate()
                return "Execution timed out"
        else:
            self.worker(command, self.globals, self.locals, queue)
        # get the result from the worker function
        return queue.get()
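A quick sketch (not part of the commit) of the two execution paths in PythonREPL.run; on platforms that use the "spawn" multiprocessing start method, run it under an `if __name__ == "__main__":` guard:

from langchain_experimental.utilities.python import PythonREPL

repl = PythonREPL()
# No timeout: the worker runs in-process and captured stdout is returned.
print(repl.run("print(2 ** 10)"))  # -> "1024\n"
# With a timeout: the command runs in a separate process and is terminated if it overruns.
print(repl.run("import time; time.sleep(5)", timeout=1))  # -> "Execution timed out"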
112  libs/experimental/tests/unit_tests/python/test_python_1.py  Normal file
@@ -0,0 +1,112 @@
"""Test functionality of Python REPL."""
import sys

import pytest

from langchain_experimental.tools.python.tool import PythonAstREPLTool, PythonREPLTool
from langchain_experimental.utilities.python import PythonREPL

_SAMPLE_CODE = """
```
def multiply():
    print(5*6)
multiply()
```
"""

_AST_SAMPLE_CODE = """
```
def multiply():
    return(5*6)
multiply()
```
"""

_AST_SAMPLE_CODE_EXECUTE = """
```
def multiply(a, b):
    return(5*6)
a = 5
b = 6

multiply(a, b)
```
"""


def test_python_repl() -> None:
    """Test functionality when globals/locals are not provided."""
    repl = PythonREPL()

    # Run a simple initial command.
    repl.run("foo = 1")
    assert repl.locals is not None
    assert repl.locals["foo"] == 1

    # Now run a command that accesses `foo` to make sure it still has it.
    repl.run("bar = foo * 2")
    assert repl.locals is not None
    assert repl.locals["bar"] == 2


def test_python_repl_no_previous_variables() -> None:
    """Test that it does not have access to variables created outside the scope."""
    foo = 3  # noqa: F841
    repl = PythonREPL()
    output = repl.run("print(foo)")
    assert output == """NameError("name 'foo' is not defined")"""


def test_python_repl_pass_in_locals() -> None:
    """Test functionality when passing in locals."""
    _locals = {"foo": 4}
    repl = PythonREPL(_locals=_locals)
    repl.run("bar = foo * 2")
    assert repl.locals is not None
    assert repl.locals["bar"] == 8


def test_functionality() -> None:
    """Test correct functionality."""
    chain = PythonREPL()
    code = "print(1 + 1)"
    output = chain.run(code)
    assert output == "2\n"


def test_functionality_multiline() -> None:
    """Test correct functionality for ChatGPT multiline commands."""
    chain = PythonREPL()
    tool = PythonREPLTool(python_repl=chain)
    output = tool.run(_SAMPLE_CODE)
    assert output == "30\n"


def test_python_ast_repl_multiline() -> None:
    """Test correct functionality for ChatGPT multiline commands."""
    if sys.version_info < (3, 9):
        pytest.skip("Python 3.9+ is required for this test")
    tool = PythonAstREPLTool()
    output = tool.run(_AST_SAMPLE_CODE)
    assert output == 30


def test_python_ast_repl_multi_statement() -> None:
    """Test correct functionality for ChatGPT multi statement commands."""
    if sys.version_info < (3, 9):
        pytest.skip("Python 3.9+ is required for this test")
    tool = PythonAstREPLTool()
    output = tool.run(_AST_SAMPLE_CODE_EXECUTE)
    assert output == 30


def test_function() -> None:
    """Test correct functionality."""
    chain = PythonREPL()
    code = "def add(a, b): " " return a + b"
    output = chain.run(code)
    assert output == ""

    code = "print(add(1, 2))"
    output = chain.run(code)
    assert output == "3\n"
164  libs/experimental/tests/unit_tests/python/test_python_2.py  Normal file
@@ -0,0 +1,164 @@
"""Test Python REPL Tools."""
import sys

import numpy as np
import pytest

from langchain_experimental.tools.python.tool import (
    PythonAstREPLTool,
    PythonREPLTool,
    sanitize_input,
)


def test_python_repl_tool_single_input() -> None:
    """Test that the python REPL tool works with a single input."""
    tool = PythonREPLTool()
    assert tool.is_single_input
    assert int(tool.run("print(1 + 1)").strip()) == 2


def test_python_repl_print() -> None:
    program = """
import numpy as np
v1 = np.array([1, 2, 3])
v2 = np.array([4, 5, 6])
dot_product = np.dot(v1, v2)
print("The dot product is {:d}.".format(dot_product))
"""
    tool = PythonREPLTool()
    assert tool.run(program) == "The dot product is 32.\n"


@pytest.mark.skipif(
    sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
)
def test_python_ast_repl_tool_single_input() -> None:
    """Test that the python REPL tool works with a single input."""
    tool = PythonAstREPLTool()
    assert tool.is_single_input
    assert tool.run("1 + 1") == 2


@pytest.mark.skipif(
    sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
)
def test_python_ast_repl_return() -> None:
    program = """
```
import numpy as np
v1 = np.array([1, 2, 3])
v2 = np.array([4, 5, 6])
dot_product = np.dot(v1, v2)
int(dot_product)
```
"""
    tool = PythonAstREPLTool()
    assert tool.run(program) == 32

    program = """
```python
import numpy as np
v1 = np.array([1, 2, 3])
v2 = np.array([4, 5, 6])
dot_product = np.dot(v1, v2)
int(dot_product)
```
"""
    assert tool.run(program) == 32


@pytest.mark.skipif(
    sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
)
def test_python_ast_repl_print() -> None:
    program = """python
string = "racecar"
if string == string[::-1]:
    print(string, "is a palindrome")
else:
    print(string, "is not a palindrome")"""
    tool = PythonAstREPLTool()
    assert tool.run(program) == "racecar is a palindrome\n"


@pytest.mark.skipif(
    sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
)
def test_repl_print_python_backticks() -> None:
    program = "`print('`python` is a great language.')`"
    tool = PythonAstREPLTool()
    assert tool.run(program) == "`python` is a great language.\n"


@pytest.mark.skipif(
    sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
)
def test_python_ast_repl_raise_exception() -> None:
    data = {"Name": ["John", "Alice"], "Age": [30, 25]}
    program = """
import pandas as pd
df = pd.DataFrame(data)
df['Gender']
"""
    tool = PythonAstREPLTool(locals={"data": data})
    expected_outputs = (
        "KeyError: 'Gender'",
        "ModuleNotFoundError: No module named 'pandas'",
    )
    assert tool.run(program) in expected_outputs


@pytest.mark.skipif(
    sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
)
def test_python_ast_repl_one_line_print() -> None:
    program = 'print("The square of {} is {:.2f}".format(3, 3**2))'
    tool = PythonAstREPLTool()
    assert tool.run(program) == "The square of 3 is 9.00\n"


@pytest.mark.skipif(
    sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
)
def test_python_ast_repl_one_line_return() -> None:
    arr = np.array([1, 2, 3, 4, 5])
    tool = PythonAstREPLTool(locals={"arr": arr})
    program = "`(arr**2).sum() # Returns sum of squares`"
    assert tool.run(program) == 55


@pytest.mark.skipif(
    sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
)
def test_python_ast_repl_one_line_exception() -> None:
    program = "[1, 2, 3][4]"
    tool = PythonAstREPLTool()
    assert tool.run(program) == "IndexError: list index out of range"


def test_sanitize_input() -> None:
    query = """
```
p = 5
```
"""
    expected = "p = 5"
    actual = sanitize_input(query)
    assert expected == actual

    query = """
```python
p = 5
```
"""
    expected = "p = 5"
    actual = sanitize_input(query)
    assert expected == actual

    query = """
p = 5
"""
    expected = "p = 5"
    actual = sanitize_input(query)
    assert expected == actual