mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-06 19:48:26 +00:00
Add python,pandas,xorbits,spark agents to experimental (#11774)
See https://github.com/langchain-ai/langchain/discussions/11680 for context.
This commit is contained in:
parent
d6e34ca2ee
commit
0d37b4c27d
@ -0,0 +1 @@
|
|||||||
|
"""Pandas toolkit."""
|
@ -0,0 +1,341 @@
|
|||||||
|
"""Agent for working with pandas objects."""
|
||||||
|
from typing import Any, Dict, List, Optional, Sequence, Tuple
|
||||||
|
|
||||||
|
from langchain.agents.agent import AgentExecutor, BaseSingleActionAgent
|
||||||
|
from langchain.agents.mrkl.base import ZeroShotAgent
|
||||||
|
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
|
||||||
|
from langchain.agents.types import AgentType
|
||||||
|
from langchain.callbacks.base import BaseCallbackManager
|
||||||
|
from langchain.chains.llm import LLMChain
|
||||||
|
from langchain.schema import BasePromptTemplate
|
||||||
|
from langchain.schema.language_model import BaseLanguageModel
|
||||||
|
from langchain.schema.messages import SystemMessage
|
||||||
|
from langchain.tools import BaseTool
|
||||||
|
|
||||||
|
from langchain_experimental.agents.agent_toolkits.pandas.prompt import (
|
||||||
|
FUNCTIONS_WITH_DF,
|
||||||
|
FUNCTIONS_WITH_MULTI_DF,
|
||||||
|
MULTI_DF_PREFIX,
|
||||||
|
MULTI_DF_PREFIX_FUNCTIONS,
|
||||||
|
PREFIX,
|
||||||
|
PREFIX_FUNCTIONS,
|
||||||
|
SUFFIX_NO_DF,
|
||||||
|
SUFFIX_WITH_DF,
|
||||||
|
SUFFIX_WITH_MULTI_DF,
|
||||||
|
)
|
||||||
|
from langchain_experimental.tools.python.tool import PythonAstREPLTool
|
||||||
|
|
||||||
|
|
||||||
|
def _get_multi_prompt(
    dfs: List[Any],
    prefix: Optional[str] = None,
    suffix: Optional[str] = None,
    input_variables: Optional[List[str]] = None,
    include_df_in_prompt: Optional[bool] = True,
    number_of_head_rows: int = 5,
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
    """Build the zero-shot prompt and REPL tool for several dataframes.

    The dataframes are exposed to the REPL tool as ``df1``, ``df2``, ... in
    the order given.

    Args:
        dfs: Dataframes to make available to the agent.
        prefix: Prompt prefix; ``MULTI_DF_PREFIX`` is used when ``None``.
        suffix: Prompt suffix. When given it is used as-is and the
            ``dfs_head`` variable is added to the default input variables.
        input_variables: Prompt variables; a default list is built when ``None``.
        include_df_in_prompt: Without a custom suffix, selects between the
            suffix that embeds the dataframe heads and the one that does not.
        number_of_head_rows: Rows taken from each dataframe for the preview.

    Returns:
        The partially-filled prompt and a one-element list holding the tool.
    """
    num_dfs = len(dfs)

    # A caller-provided suffix takes precedence over both default suffixes.
    if suffix is not None:
        suffix_to_use, include_dfs_head = suffix, True
    elif include_df_in_prompt:
        suffix_to_use, include_dfs_head = SUFFIX_WITH_MULTI_DF, True
    else:
        suffix_to_use, include_dfs_head = SUFFIX_NO_DF, False

    if input_variables is None:
        input_variables = ["input", "agent_scratchpad", "num_dfs"]
        if include_dfs_head:
            input_variables.append("dfs_head")

    prefix = MULTI_DF_PREFIX if prefix is None else prefix

    repl_locals = {f"df{pos}": frame for pos, frame in enumerate(dfs, start=1)}
    tools = [PythonAstREPLTool(locals=repl_locals)]

    prompt = ZeroShotAgent.create_prompt(
        tools,
        prefix=prefix,
        suffix=suffix_to_use,
        input_variables=input_variables,
    )

    # Pre-fill the variables that are known now so that only the question and
    # the scratchpad remain to be supplied at run time.
    partial_prompt = prompt.partial()
    if "dfs_head" in input_variables:
        previews = [frame.head(number_of_head_rows).to_markdown() for frame in dfs]
        partial_prompt = partial_prompt.partial(
            num_dfs=str(num_dfs), dfs_head="\n\n".join(previews)
        )
    if "num_dfs" in input_variables:
        partial_prompt = partial_prompt.partial(num_dfs=str(num_dfs))
    return partial_prompt, tools
|
||||||
|
|
||||||
|
|
||||||
|
def _get_single_prompt(
    df: Any,
    prefix: Optional[str] = None,
    suffix: Optional[str] = None,
    input_variables: Optional[List[str]] = None,
    include_df_in_prompt: Optional[bool] = True,
    number_of_head_rows: int = 5,
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
    """Build the zero-shot prompt and REPL tool for a single dataframe.

    The dataframe is exposed to the REPL tool under the name ``df``.

    Args:
        df: Dataframe to make available to the agent.
        prefix: Prompt prefix; ``PREFIX`` is used when ``None``.
        suffix: Prompt suffix. When given it is used as-is and the
            ``df_head`` variable is added to the default input variables.
        input_variables: Prompt variables; a default list is built when ``None``.
        include_df_in_prompt: Without a custom suffix, selects between the
            suffix that embeds the dataframe head and the one that does not.
        number_of_head_rows: Rows taken from the dataframe for the preview.

    Returns:
        The partially-filled prompt and a one-element list holding the tool.
    """
    # A caller-provided suffix takes precedence over both default suffixes.
    if suffix is not None:
        suffix_to_use, include_df_head = suffix, True
    elif include_df_in_prompt:
        suffix_to_use, include_df_head = SUFFIX_WITH_DF, True
    else:
        suffix_to_use, include_df_head = SUFFIX_NO_DF, False

    if input_variables is None:
        input_variables = ["input", "agent_scratchpad"]
        if include_df_head:
            input_variables.append("df_head")

    prefix = PREFIX if prefix is None else prefix

    tools = [PythonAstREPLTool(locals={"df": df})]

    prompt = ZeroShotAgent.create_prompt(
        tools,
        prefix=prefix,
        suffix=suffix_to_use,
        input_variables=input_variables,
    )

    partial_prompt = prompt.partial()
    if "df_head" in input_variables:
        # The preview is rendered as a markdown table and baked into the prompt.
        head_markdown = str(df.head(number_of_head_rows).to_markdown())
        partial_prompt = partial_prompt.partial(df_head=head_markdown)
    return partial_prompt, tools
|
||||||
|
|
||||||
|
|
||||||
|
def _get_prompt_and_tools(
|
||||||
|
df: Any,
|
||||||
|
prefix: Optional[str] = None,
|
||||||
|
suffix: Optional[str] = None,
|
||||||
|
input_variables: Optional[List[str]] = None,
|
||||||
|
include_df_in_prompt: Optional[bool] = True,
|
||||||
|
number_of_head_rows: int = 5,
|
||||||
|
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
|
||||||
|
try:
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
pd.set_option("display.max_columns", None)
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"pandas package not found, please install with `pip install pandas`"
|
||||||
|
)
|
||||||
|
|
||||||
|
if include_df_in_prompt is not None and suffix is not None:
|
||||||
|
raise ValueError("If suffix is specified, include_df_in_prompt should not be.")
|
||||||
|
|
||||||
|
if isinstance(df, list):
|
||||||
|
for item in df:
|
||||||
|
if not isinstance(item, pd.DataFrame):
|
||||||
|
raise ValueError(f"Expected pandas object, got {type(df)}")
|
||||||
|
return _get_multi_prompt(
|
||||||
|
df,
|
||||||
|
prefix=prefix,
|
||||||
|
suffix=suffix,
|
||||||
|
input_variables=input_variables,
|
||||||
|
include_df_in_prompt=include_df_in_prompt,
|
||||||
|
number_of_head_rows=number_of_head_rows,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if not isinstance(df, pd.DataFrame):
|
||||||
|
raise ValueError(f"Expected pandas object, got {type(df)}")
|
||||||
|
return _get_single_prompt(
|
||||||
|
df,
|
||||||
|
prefix=prefix,
|
||||||
|
suffix=suffix,
|
||||||
|
input_variables=input_variables,
|
||||||
|
include_df_in_prompt=include_df_in_prompt,
|
||||||
|
number_of_head_rows=number_of_head_rows,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_functions_single_prompt(
    df: Any,
    prefix: Optional[str] = None,
    suffix: Optional[str] = None,
    include_df_in_prompt: Optional[bool] = True,
    number_of_head_rows: int = 5,
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
    """Build the OpenAI-functions prompt and REPL tool for one dataframe.

    The system message is ``prefix`` followed by the suffix. When
    ``include_df_in_prompt`` is truthy the suffix (custom or default) is
    formatted with ``df_head``, a markdown rendering of the first rows.

    Args:
        df: Dataframe exposed to the REPL tool as ``df``.
        prefix: System-message prefix; ``PREFIX_FUNCTIONS`` when ``None``.
        suffix: System-message suffix template, or ``None`` for the default.
        include_df_in_prompt: Whether to embed the dataframe preview.
        number_of_head_rows: Rows used for the preview.

    Returns:
        The prompt and a one-element list holding the tool.
    """
    if include_df_in_prompt:
        head_markdown = str(df.head(number_of_head_rows).to_markdown())
        template = FUNCTIONS_WITH_DF if suffix is None else suffix
        suffix_to_use = template.format(df_head=head_markdown)
    else:
        # No preview requested: a custom suffix is used verbatim.
        suffix_to_use = "" if suffix is None else suffix

    if prefix is None:
        prefix = PREFIX_FUNCTIONS

    tools = [PythonAstREPLTool(locals={"df": df})]
    prompt = OpenAIFunctionsAgent.create_prompt(
        system_message=SystemMessage(content=prefix + suffix_to_use)
    )
    return prompt, tools
|
||||||
|
|
||||||
|
|
||||||
|
def _get_functions_multi_prompt(
    dfs: Any,
    prefix: Optional[str] = None,
    suffix: Optional[str] = None,
    include_df_in_prompt: Optional[bool] = True,
    number_of_head_rows: int = 5,
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
    """Build the OpenAI-functions prompt and REPL tool for several dataframes.

    The system message is the prefix (formatted with ``num_dfs``) followed by
    the suffix. When ``include_df_in_prompt`` is truthy the suffix (custom or
    default) is formatted with ``dfs_head``, the markdown previews of all
    dataframes joined by blank lines.

    Args:
        dfs: Dataframes exposed to the REPL tool as ``df1``, ``df2``, ...
        prefix: Prefix template; ``MULTI_DF_PREFIX_FUNCTIONS`` when ``None``.
        suffix: Suffix template, or ``None`` for the default.
        include_df_in_prompt: Whether to embed the dataframe previews.
        number_of_head_rows: Rows used for each preview.

    Returns:
        The prompt and a one-element list holding the tool.
    """
    if include_df_in_prompt:
        dfs_head = "\n\n".join(
            frame.head(number_of_head_rows).to_markdown() for frame in dfs
        )
        template = FUNCTIONS_WITH_MULTI_DF if suffix is None else suffix
        suffix_to_use = template.format(dfs_head=dfs_head)
    else:
        # No preview requested: a custom suffix is used verbatim.
        suffix_to_use = "" if suffix is None else suffix

    if prefix is None:
        prefix = MULTI_DF_PREFIX_FUNCTIONS
    # The prefix (default or custom) is always formatted with the count.
    prefix = prefix.format(num_dfs=str(len(dfs)))

    repl_locals = {f"df{pos}": frame for pos, frame in enumerate(dfs, start=1)}
    tools = [PythonAstREPLTool(locals=repl_locals)]
    prompt = OpenAIFunctionsAgent.create_prompt(
        system_message=SystemMessage(content=prefix + suffix_to_use)
    )
    return prompt, tools
|
||||||
|
|
||||||
|
|
||||||
|
def _get_functions_prompt_and_tools(
|
||||||
|
df: Any,
|
||||||
|
prefix: Optional[str] = None,
|
||||||
|
suffix: Optional[str] = None,
|
||||||
|
input_variables: Optional[List[str]] = None,
|
||||||
|
include_df_in_prompt: Optional[bool] = True,
|
||||||
|
number_of_head_rows: int = 5,
|
||||||
|
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
|
||||||
|
try:
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
pd.set_option("display.max_columns", None)
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"pandas package not found, please install with `pip install pandas`"
|
||||||
|
)
|
||||||
|
if input_variables is not None:
|
||||||
|
raise ValueError("`input_variables` is not supported at the moment.")
|
||||||
|
|
||||||
|
if include_df_in_prompt is not None and suffix is not None:
|
||||||
|
raise ValueError("If suffix is specified, include_df_in_prompt should not be.")
|
||||||
|
|
||||||
|
if isinstance(df, list):
|
||||||
|
for item in df:
|
||||||
|
if not isinstance(item, pd.DataFrame):
|
||||||
|
raise ValueError(f"Expected pandas object, got {type(df)}")
|
||||||
|
return _get_functions_multi_prompt(
|
||||||
|
df,
|
||||||
|
prefix=prefix,
|
||||||
|
suffix=suffix,
|
||||||
|
include_df_in_prompt=include_df_in_prompt,
|
||||||
|
number_of_head_rows=number_of_head_rows,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if not isinstance(df, pd.DataFrame):
|
||||||
|
raise ValueError(f"Expected pandas object, got {type(df)}")
|
||||||
|
return _get_functions_single_prompt(
|
||||||
|
df,
|
||||||
|
prefix=prefix,
|
||||||
|
suffix=suffix,
|
||||||
|
include_df_in_prompt=include_df_in_prompt,
|
||||||
|
number_of_head_rows=number_of_head_rows,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_pandas_dataframe_agent(
    llm: BaseLanguageModel,
    df: Any,
    agent_type: AgentType = AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    callback_manager: Optional[BaseCallbackManager] = None,
    prefix: Optional[str] = None,
    suffix: Optional[str] = None,
    input_variables: Optional[List[str]] = None,
    verbose: bool = False,
    return_intermediate_steps: bool = False,
    max_iterations: Optional[int] = 15,
    max_execution_time: Optional[float] = None,
    early_stopping_method: str = "force",
    agent_executor_kwargs: Optional[Dict[str, Any]] = None,
    include_df_in_prompt: Optional[bool] = True,
    number_of_head_rows: int = 5,
    extra_tools: Sequence[BaseTool] = (),
    **kwargs: Dict[str, Any],
) -> AgentExecutor:
    """Construct a pandas agent from an LLM and dataframe.

    Args:
        llm: Language model driving the agent.
        df: A pandas ``DataFrame`` or a list of them.
        agent_type: ``ZERO_SHOT_REACT_DESCRIPTION`` or ``OPENAI_FUNCTIONS``.
        callback_manager: Passed to the chain, the agent and the executor.
        prefix: Optional prompt prefix.
        suffix: Optional prompt suffix.
        input_variables: Optional prompt variables (zero-shot agent only).
        verbose: Forwarded to the executor.
        return_intermediate_steps: Forwarded to the executor.
        max_iterations: Forwarded to the executor.
        max_execution_time: Forwarded to the executor.
        early_stopping_method: Forwarded to the executor.
        agent_executor_kwargs: Extra keyword arguments for the executor.
        include_df_in_prompt: Whether to embed a preview of the data.
        number_of_head_rows: Rows used for the preview.
        extra_tools: Tools appended after the built-in REPL tool.
        **kwargs: Extra keyword arguments for the agent constructor.

    Returns:
        An ``AgentExecutor`` wrapping the agent and its tools.

    Raises:
        ValueError: If ``agent_type`` is not one of the two supported types.
    """
    agent: BaseSingleActionAgent

    supported_types = (
        AgentType.ZERO_SHOT_REACT_DESCRIPTION,
        AgentType.OPENAI_FUNCTIONS,
    )
    if agent_type not in supported_types:
        raise ValueError(f"Agent type {agent_type} not supported at the moment.")

    # Options shared by both prompt builders.
    prompt_options = {
        "prefix": prefix,
        "suffix": suffix,
        "input_variables": input_variables,
        "include_df_in_prompt": include_df_in_prompt,
        "number_of_head_rows": number_of_head_rows,
    }

    if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION:
        prompt, base_tools = _get_prompt_and_tools(df, **prompt_options)
        tools = base_tools + list(extra_tools)
        agent = ZeroShotAgent(
            llm_chain=LLMChain(
                llm=llm,
                prompt=prompt,
                callback_manager=callback_manager,
            ),
            allowed_tools=[tool.name for tool in tools],
            callback_manager=callback_manager,
            **kwargs,
        )
    else:
        _prompt, base_tools = _get_functions_prompt_and_tools(df, **prompt_options)
        tools = base_tools + list(extra_tools)
        agent = OpenAIFunctionsAgent(
            llm=llm,
            prompt=_prompt,
            tools=tools,
            callback_manager=callback_manager,
            **kwargs,
        )

    return AgentExecutor.from_agent_and_tools(
        agent=agent,
        tools=tools,
        callback_manager=callback_manager,
        verbose=verbose,
        return_intermediate_steps=return_intermediate_steps,
        max_iterations=max_iterations,
        max_execution_time=max_execution_time,
        early_stopping_method=early_stopping_method,
        **(agent_executor_kwargs or {}),
    )
|
@ -0,0 +1,44 @@
|
|||||||
|
# flake8: noqa
|
||||||
|
|
||||||
|
# --- Prompts for the zero-shot (ReAct-style) agent --------------------------

# Prefix for a single dataframe; takes no placeholders.
PREFIX = """
You are working with a pandas dataframe in Python. The name of the dataframe is `df`.
You should use the tools below to answer the question posed of you:"""

# Prefix for several dataframes; placeholder: {num_dfs}.
MULTI_DF_PREFIX = """
You are working with {num_dfs} pandas dataframes in Python named df1, df2, etc. You
should use the tools below to answer the question posed of you:"""

# Suffix without a data preview; placeholders: {input}, {agent_scratchpad}.
SUFFIX_NO_DF = """
Begin!
Question: {input}
{agent_scratchpad}"""

# Suffix with a single-dataframe preview; adds the {df_head} placeholder.
SUFFIX_WITH_DF = """
This is the result of `print(df.head())`:
{df_head}

Begin!
Question: {input}
{agent_scratchpad}"""

# Suffix with previews of every dataframe; adds the {dfs_head} placeholder.
SUFFIX_WITH_MULTI_DF = """
This is the result of `print(df.head())` for each dataframe:
{dfs_head}

Begin!
Question: {input}
{agent_scratchpad}"""

# --- System-message fragments for the OpenAI-functions agent ----------------

# Prefix for a single dataframe; takes no placeholders.
PREFIX_FUNCTIONS = """
You are working with a pandas dataframe in Python. The name of the dataframe is `df`."""

# Prefix for several dataframes; placeholder: {num_dfs}.
MULTI_DF_PREFIX_FUNCTIONS = """
You are working with {num_dfs} pandas dataframes in Python named df1, df2, etc."""

# Suffix with a single-dataframe preview; placeholder: {df_head}.
FUNCTIONS_WITH_DF = """
This is the result of `print(df.head())`:
{df_head}"""

# Suffix with previews of every dataframe; placeholder: {dfs_head}.
FUNCTIONS_WITH_MULTI_DF = """
This is the result of `print(df.head())` for each dataframe:
{dfs_head}"""
|
@ -0,0 +1,59 @@
|
|||||||
|
"""Python agent."""
|
||||||
|
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from langchain.agents.agent import AgentExecutor, BaseSingleActionAgent
|
||||||
|
from langchain.agents.mrkl.base import ZeroShotAgent
|
||||||
|
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
|
||||||
|
from langchain.agents.types import AgentType
|
||||||
|
from langchain.callbacks.base import BaseCallbackManager
|
||||||
|
from langchain.chains.llm import LLMChain
|
||||||
|
from langchain.schema.language_model import BaseLanguageModel
|
||||||
|
from langchain.schema.messages import SystemMessage
|
||||||
|
|
||||||
|
from langchain_experimental.agents.agent_toolkits.python.prompt import PREFIX
|
||||||
|
from langchain_experimental.tools.python.tool import PythonREPLTool
|
||||||
|
|
||||||
|
|
||||||
|
def create_python_agent(
    llm: BaseLanguageModel,
    tool: PythonREPLTool,
    agent_type: AgentType = AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    callback_manager: Optional[BaseCallbackManager] = None,
    verbose: bool = False,
    prefix: str = PREFIX,
    agent_executor_kwargs: Optional[Dict[str, Any]] = None,
    **kwargs: Dict[str, Any],
) -> AgentExecutor:
    """Construct a python agent from an LLM and tool.

    Args:
        llm: Language model driving the agent.
        tool: The Python REPL tool; it is the agent's only tool.
        agent_type: ``ZERO_SHOT_REACT_DESCRIPTION`` or ``OPENAI_FUNCTIONS``.
        callback_manager: Callback manager for the chain/agent and executor.
        verbose: Forwarded to the executor.
        prefix: Prompt prefix (zero-shot) or system message (functions).
        agent_executor_kwargs: Extra keyword arguments for the executor.
        **kwargs: Extra keyword arguments for the agent constructor.

    Returns:
        An ``AgentExecutor`` wrapping the agent and the tool.

    Raises:
        ValueError: If ``agent_type`` is not one of the two supported types.
    """
    tools = [tool]
    agent: BaseSingleActionAgent

    supported_types = (
        AgentType.ZERO_SHOT_REACT_DESCRIPTION,
        AgentType.OPENAI_FUNCTIONS,
    )
    if agent_type not in supported_types:
        raise ValueError(f"Agent type {agent_type} not supported at the moment.")

    if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION:
        llm_chain = LLMChain(
            llm=llm,
            prompt=ZeroShotAgent.create_prompt(tools, prefix=prefix),
            callback_manager=callback_manager,
        )
        allowed = [entry.name for entry in tools]
        agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=allowed, **kwargs)
    else:
        _prompt = OpenAIFunctionsAgent.create_prompt(
            system_message=SystemMessage(content=prefix)
        )
        agent = OpenAIFunctionsAgent(
            llm=llm,
            prompt=_prompt,
            tools=tools,
            callback_manager=callback_manager,
            **kwargs,
        )

    return AgentExecutor.from_agent_and_tools(
        agent=agent,
        tools=tools,
        callback_manager=callback_manager,
        verbose=verbose,
        **(agent_executor_kwargs or {}),
    )
|
@ -0,0 +1,9 @@
|
|||||||
|
# flake8: noqa
|
||||||
|
|
||||||
|
# Default instructions for the Python agent; takes no placeholders.
PREFIX = """You are an agent designed to write and execute python code to answer questions.
You have access to a python REPL, which you can use to execute python code.
If you get an error, debug your code and try again.
Only use the output of your code to answer the question.
You might know the answer without running any code, but you should still run the code to get the answer.
If it does not seem like you can write code to answer the question, just return "I don't know" as the answer.
"""
|
@ -0,0 +1 @@
|
|||||||
|
"""spark toolkit"""
|
@ -0,0 +1,81 @@
|
|||||||
|
"""Agent for working with Spark objects."""
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from langchain.agents.agent import AgentExecutor
|
||||||
|
from langchain.agents.mrkl.base import ZeroShotAgent
|
||||||
|
from langchain.callbacks.base import BaseCallbackManager
|
||||||
|
from langchain.chains.llm import LLMChain
|
||||||
|
from langchain.llms.base import BaseLLM
|
||||||
|
|
||||||
|
from langchain_experimental.agents.agent_toolkits.spark.prompt import PREFIX, SUFFIX
|
||||||
|
from langchain_experimental.tools.python.tool import PythonAstREPLTool
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_spark_df(df: Any) -> bool:
|
||||||
|
try:
|
||||||
|
from pyspark.sql import DataFrame as SparkLocalDataFrame
|
||||||
|
|
||||||
|
return isinstance(df, SparkLocalDataFrame)
|
||||||
|
except ImportError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_spark_connect_df(df: Any) -> bool:
|
||||||
|
try:
|
||||||
|
from pyspark.sql.connect.dataframe import DataFrame as SparkConnectDataFrame
|
||||||
|
|
||||||
|
return isinstance(df, SparkConnectDataFrame)
|
||||||
|
except ImportError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def create_spark_dataframe_agent(
    llm: BaseLLM,
    df: Any,
    callback_manager: Optional[BaseCallbackManager] = None,
    prefix: str = PREFIX,
    suffix: str = SUFFIX,
    input_variables: Optional[List[str]] = None,
    verbose: bool = False,
    return_intermediate_steps: bool = False,
    max_iterations: Optional[int] = 15,
    max_execution_time: Optional[float] = None,
    early_stopping_method: str = "force",
    agent_executor_kwargs: Optional[Dict[str, Any]] = None,
    **kwargs: Dict[str, Any],
) -> AgentExecutor:
    """Construct a Spark agent from an LLM and dataframe.

    Args:
        llm: Language model driving the agent.
        df: A Spark (local or Spark Connect) dataframe, exposed to the REPL
            tool as ``df``.
        callback_manager: Passed to the chain, the agent and the executor.
        prefix: Prompt prefix.
        suffix: Prompt suffix.
        input_variables: Prompt variables; defaults to
            ``["df", "input", "agent_scratchpad"]``.
        verbose: Forwarded to the executor.
        return_intermediate_steps: Forwarded to the executor.
        max_iterations: Forwarded to the executor.
        max_execution_time: Forwarded to the executor.
        early_stopping_method: Forwarded to the executor.
        agent_executor_kwargs: Extra keyword arguments for the executor.
        **kwargs: Extra keyword arguments for the agent constructor.

    Returns:
        An ``AgentExecutor`` wrapping the agent and its REPL tool.

    Raises:
        ImportError: If ``df`` passes neither Spark dataframe check.
    """
    is_spark_frame = _validate_spark_df(df) or _validate_spark_connect_df(df)
    if not is_spark_frame:
        # NOTE(review): both validators also return False when pyspark is
        # installed but `df` has another type, so this message can mislead.
        raise ImportError("Spark is not installed. run `pip install pyspark`.")

    if input_variables is None:
        input_variables = ["df", "input", "agent_scratchpad"]

    tools = [PythonAstREPLTool(locals={"df": df})]
    prompt = ZeroShotAgent.create_prompt(
        tools,
        prefix=prefix,
        suffix=suffix,
        input_variables=input_variables,
    )
    # The first row is rendered into the prompt as a preview of the data.
    first_row = str(df.first())
    llm_chain = LLMChain(
        llm=llm,
        prompt=prompt.partial(df=first_row),
        callback_manager=callback_manager,
    )
    agent = ZeroShotAgent(
        llm_chain=llm_chain,
        allowed_tools=[entry.name for entry in tools],
        callback_manager=callback_manager,
        **kwargs,
    )
    return AgentExecutor.from_agent_and_tools(
        agent=agent,
        tools=tools,
        callback_manager=callback_manager,
        verbose=verbose,
        return_intermediate_steps=return_intermediate_steps,
        max_iterations=max_iterations,
        max_execution_time=max_execution_time,
        early_stopping_method=early_stopping_method,
        **(agent_executor_kwargs or {}),
    )
|
@ -0,0 +1,13 @@
|
|||||||
|
# flake8: noqa
|
||||||
|
|
||||||
|
# Prompt prefix for the Spark agent; takes no placeholders.
PREFIX = """
You are working with a spark dataframe in Python. The name of the dataframe is `df`.
You should use the tools below to answer the question posed of you:"""

# Prompt suffix; placeholders: {df}, {input}, {agent_scratchpad}.
SUFFIX = """
This is the result of `print(df.first())`:
{df}

Begin!
Question: {input}
{agent_scratchpad}"""
|
@ -0,0 +1 @@
|
|||||||
|
"""Xorbits toolkit."""
|
@ -0,0 +1,91 @@
|
|||||||
|
"""Agent for working with xorbits objects."""
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from langchain.agents.agent import AgentExecutor
|
||||||
|
from langchain.agents.mrkl.base import ZeroShotAgent
|
||||||
|
from langchain.callbacks.base import BaseCallbackManager
|
||||||
|
from langchain.chains.llm import LLMChain
|
||||||
|
from langchain.llms.base import BaseLLM
|
||||||
|
|
||||||
|
from langchain_experimental.agents.agent_toolkits.xorbits.prompt import (
|
||||||
|
NP_PREFIX,
|
||||||
|
NP_SUFFIX,
|
||||||
|
PD_PREFIX,
|
||||||
|
PD_SUFFIX,
|
||||||
|
)
|
||||||
|
from langchain_experimental.tools.python.tool import PythonAstREPLTool
|
||||||
|
|
||||||
|
|
||||||
|
def create_xorbits_agent(
    llm: BaseLLM,
    data: Any,
    callback_manager: Optional[BaseCallbackManager] = None,
    prefix: str = "",
    suffix: str = "",
    input_variables: Optional[List[str]] = None,
    verbose: bool = False,
    return_intermediate_steps: bool = False,
    max_iterations: Optional[int] = 15,
    max_execution_time: Optional[float] = None,
    early_stopping_method: str = "force",
    agent_executor_kwargs: Optional[Dict[str, Any]] = None,
    **kwargs: Dict[str, Any],
) -> AgentExecutor:
    """Construct a xorbits agent from an LLM and dataframe.

    Args:
        llm: Language model driving the agent.
        data: A ``xorbits.pandas.DataFrame`` or ``xorbits.numpy.ndarray``,
            exposed to the REPL tool as ``data``.
        callback_manager: Passed to the chain, the agent and the executor.
        prefix: Prompt prefix; an empty string selects the default for the
            type of ``data``.
        suffix: Prompt suffix; an empty string selects the default for the
            type of ``data``.
        input_variables: Prompt variables; defaults to
            ``["data", "input", "agent_scratchpad"]``.
        verbose: Forwarded to the executor.
        return_intermediate_steps: Forwarded to the executor.
        max_iterations: Forwarded to the executor.
        max_execution_time: Forwarded to the executor.
        early_stopping_method: Forwarded to the executor.
        agent_executor_kwargs: Extra keyword arguments for the executor.
        **kwargs: Extra keyword arguments for the agent constructor.

    Returns:
        An ``AgentExecutor`` wrapping the agent and its REPL tool.

    Raises:
        ImportError: If xorbits is not installed.
        ValueError: If ``data`` is neither a xorbits DataFrame nor ndarray.
    """
    try:
        from xorbits import numpy as np
        from xorbits import pandas as pd
    except ImportError:
        raise ImportError(
            "Xorbits package not installed, please install with `pip install xorbits`"
        )

    if not isinstance(data, (pd.DataFrame, np.ndarray)):
        raise ValueError(
            f"Expected Xorbits DataFrame or ndarray object, got {type(data)}"
        )
    if input_variables is None:
        input_variables = ["data", "input", "agent_scratchpad"]
    tools = [PythonAstREPLTool(locals={"data": data})]

    # The defaults and the preview depend on whether `data` is a DataFrame
    # or an ndarray.
    is_frame = isinstance(data, pd.DataFrame)
    default_prefix = PD_PREFIX if is_frame else NP_PREFIX
    default_suffix = PD_SUFFIX if is_frame else NP_SUFFIX
    prompt = ZeroShotAgent.create_prompt(
        tools,
        prefix=default_prefix if prefix == "" else prefix,
        suffix=default_suffix if suffix == "" else suffix,
        input_variables=input_variables,
    )
    if is_frame:
        partial_input = str(data.head())
    else:
        # For arrays the preview is the first half along the leading axis.
        partial_input = str(data[: len(data) // 2])

    llm_chain = LLMChain(
        llm=llm,
        prompt=prompt.partial(data=partial_input),
        callback_manager=callback_manager,
    )
    agent = ZeroShotAgent(
        llm_chain=llm_chain,
        allowed_tools=[entry.name for entry in tools],
        callback_manager=callback_manager,
        **kwargs,
    )
    return AgentExecutor.from_agent_and_tools(
        agent=agent,
        tools=tools,
        callback_manager=callback_manager,
        verbose=verbose,
        return_intermediate_steps=return_intermediate_steps,
        max_iterations=max_iterations,
        max_execution_time=max_execution_time,
        early_stopping_method=early_stopping_method,
        **(agent_executor_kwargs or {}),
    )
|
@ -0,0 +1,33 @@
|
|||||||
|
# Prompt prefix used when the input is a Xorbits DataFrame; no placeholders.
# The import hint names Pandas so that it matches the
# `import xorbits.pandas as pd` example it introduces.
PD_PREFIX = """
You are working with Xorbits dataframe object in Python.
Before importing Numpy or Pandas in the current script,
remember to import the xorbits version of the library instead.
To import the xorbits version of Pandas, replace the original import statement
`import pandas as pd` with `import xorbits.pandas as pd`.
The name of the input is `data`.
You should use the tools below to answer the question posed of you:"""

# Prompt suffix for DataFrame input; placeholders: {data}, {input},
# {agent_scratchpad}.
PD_SUFFIX = """
This is the result of `print(data)`:
{data}

Begin!
Question: {input}
{agent_scratchpad}"""

# Prompt prefix used when the input is a Xorbits ndarray; no placeholders.
NP_PREFIX = """
You are working with Xorbits ndarray object in Python.
Before importing Numpy in the current script,
remember to import the xorbits version of the library instead.
To import the xorbits version of Numpy, replace the original import statement
`import numpy as np` with `import xorbits.numpy as np`.
The name of the input is `data`.
You should use the tools below to answer the question posed of you:"""

# Prompt suffix for ndarray input; placeholders: {data}, {input},
# {agent_scratchpad}.
NP_SUFFIX = """
This is the result of `print(data)`:
{data}

Begin!
Question: {input}
{agent_scratchpad}"""
|
150
libs/experimental/langchain_experimental/tools/python/tool.py
Normal file
150
libs/experimental/langchain_experimental/tools/python/tool.py
Normal file
@ -0,0 +1,150 @@
|
|||||||
|
"""A tool for running python code in a REPL."""
|
||||||
|
|
||||||
|
import ast
|
||||||
|
import asyncio
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
from contextlib import redirect_stdout
|
||||||
|
from io import StringIO
|
||||||
|
from typing import Any, Dict, Optional, Type
|
||||||
|
|
||||||
|
from langchain.callbacks.manager import (
|
||||||
|
AsyncCallbackManagerForToolRun,
|
||||||
|
CallbackManagerForToolRun,
|
||||||
|
)
|
||||||
|
from langchain.pydantic_v1 import BaseModel, Field, root_validator
|
||||||
|
from langchain.tools.base import BaseTool
|
||||||
|
|
||||||
|
from langchain_experimental.utilities.python import PythonREPL
|
||||||
|
|
||||||
|
|
||||||
|
def _get_default_python_repl() -> PythonREPL:
    """Create the default REPL, seeded with this module's globals and no locals."""
    module_globals = globals()
    return PythonREPL(_globals=module_globals, _locals=None)
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_input(query: str) -> str:
    """Clean up a code snippet before it is handed to the python REPL.

    Strips surrounding whitespace and backticks, plus a leading ``python``
    tag in any letter case (as left behind by a markdown code fence).

    Args:
        query: The raw snippet to clean.

    Returns:
        str: The cleaned snippet.
    """
    # Leading whitespace/backticks, an optional "python" tag, then whitespace.
    without_prefix = re.sub(r"^(\s|`)*(?i:python)?\s*", "", query)
    # Trailing whitespace/backticks.
    return re.sub(r"(\s|`)*$", "", without_prefix)
|
||||||
|
|
||||||
|
|
||||||
|
class PythonREPLTool(BaseTool):
    """A tool for running python code in a REPL."""

    name: str = "Python_REPL"
    description: str = (
        "A Python shell. Use this to execute python commands. "
        "Input should be a valid python command. "
        "If you want to see the output of a value, you should print it out "
        "with `print(...)`."
    )
    # REPL that executes the code; a default one is created per tool.
    python_repl: PythonREPL = Field(default_factory=_get_default_python_repl)
    # When true, queries are passed through `sanitize_input` before running.
    sanitize_input: bool = True

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Any:
        """Use the tool."""
        code = sanitize_input(query) if self.sanitize_input else query
        return self.python_repl.run(code)

    async def _arun(
        self,
        query: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Any:
        """Use the tool asynchronously."""
        code = sanitize_input(query) if self.sanitize_input else query

        # The synchronous `run` is blocking, so it goes to the loop's default
        # executor instead of running on the event loop.
        running_loop = asyncio.get_running_loop()
        return await running_loop.run_in_executor(None, self.run, code)
|
||||||
|
|
||||||
|
|
||||||
|
class PythonInputs(BaseModel):
    # Input schema with a single field: the source code to execute.
    query: str = Field(description="code snippet to run")
|
||||||
|
|
||||||
|
|
||||||
|
class PythonAstREPLTool(BaseTool):
    """A tool for running python code in a REPL."""

    name: str = "python_repl_ast"
    description: str = (
        "A Python shell. Use this to execute python commands. "
        "Input should be a valid python command. "
        "When using this tool, sometimes output is abbreviated - "
        "make sure it does not look abbreviated before using it in your answer."
    )
    # Namespaces handed to `exec`/`eval`; the same dicts are used on every
    # call, so names defined by one query are visible to later ones.
    globals: Optional[Dict] = Field(default_factory=dict)
    locals: Optional[Dict] = Field(default_factory=dict)
    # When true, queries are passed through `sanitize_input` before parsing.
    sanitize_input: bool = True
    args_schema: Type[BaseModel] = PythonInputs

    @root_validator(pre=True)
    def validate_python_version(cls, values: Dict) -> Dict:
        """Validate valid python version."""
        # `ast.unparse`, used by `_run`, only exists on Python 3.9+.
        if sys.version_info < (3, 9):
            raise ValueError(
                "This tool relies on Python 3.9 or higher "
                "(as it uses new functionality in the `ast` module, "
                f"you have Python version: {sys.version}"
            )
        return values

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Use the tool.

        Executes every statement but the last, then evaluates the last one.
        If the last statement is an expression with a non-``None`` value,
        that value is returned; otherwise whatever the last statement wrote
        to stdout is returned. Any exception is reported as
        ``"<ExceptionType>: <message>"`` instead of being raised.

        Args:
            query: Python source code to run.
            run_manager: Unused.

        Returns:
            The value of the final expression, the captured stdout, or the
            formatted error.
        """
        # NOTE(security): this runs arbitrary code through exec/eval in the
        # current process; only use it with input you trust.
        try:
            if self.sanitize_input:
                query = sanitize_input(query)
            tree = ast.parse(query)
            # Everything before the final statement runs for its effects only.
            module = ast.Module(tree.body[:-1], type_ignores=[])
            exec(ast.unparse(module), self.globals, self.locals)  # type: ignore
            module_end = ast.Module(tree.body[-1:], type_ignores=[])
            module_end_str = ast.unparse(module_end)  # type: ignore
            io_buffer = StringIO()
            # Pick eval vs exec from the node type instead of retrying with
            # exec after a failed eval: the retry ran a raising expression a
            # second time and so duplicated its side effects.
            last_is_expression = bool(tree.body) and isinstance(
                tree.body[-1], ast.Expr
            )
            with redirect_stdout(io_buffer):
                if last_is_expression:
                    ret = eval(module_end_str, self.globals, self.locals)
                else:
                    exec(module_end_str, self.globals, self.locals)
                    ret = None
            if ret is None:
                return io_buffer.getvalue()
            return ret
        except Exception as e:
            return "{}: {}".format(type(e).__name__, str(e))

    async def _arun(
        self,
        query: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Any:
        """Use the tool asynchronously."""
        # `_run` is blocking, so it goes to the loop's default executor.
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(None, self._run, query)

        return result
|
71
libs/experimental/langchain_experimental/utilities/python.py
Normal file
71
libs/experimental/langchain_experimental/utilities/python.py
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
import functools
|
||||||
|
import logging
|
||||||
|
import multiprocessing
|
||||||
|
import sys
|
||||||
|
from io import StringIO
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
from langchain.pydantic_v1 import BaseModel, Field
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@functools.lru_cache(maxsize=None)
def warn_once() -> None:
    """Log the arbitrary-code-execution warning for the Python REPL.

    The argument-less call is memoized by ``lru_cache``, so the warning is
    emitted only the first time this function is called.
    """
    message = "Python REPL can execute arbitrary code. Use with caution."
    logger.warning(message)
|
||||||
|
|
||||||
|
|
||||||
|
class PythonREPL(BaseModel):
    """Simulates a standalone Python REPL."""

    # Namespaces passed to ``exec``; settable through the ``_globals`` and
    # ``_locals`` aliases.
    globals: Optional[Dict] = Field(default_factory=dict, alias="_globals")
    locals: Optional[Dict] = Field(default_factory=dict, alias="_locals")

    @classmethod
    def worker(
        cls,
        command: str,
        globals: Optional[Dict],
        locals: Optional[Dict],
        queue: multiprocessing.Queue,
    ) -> None:
        """Execute ``command`` and put the outcome on ``queue``.

        Args:
            command: Python source to execute.
            globals: global namespace for ``exec``.
            locals: local namespace for ``exec``.
            queue: receives exactly one string: the captured stdout on
                success, or ``repr(exception)`` if the command raised.
        """
        captured = StringIO()
        old_stdout = sys.stdout
        sys.stdout = captured
        try:
            # SECURITY: executes arbitrary code supplied by the caller.
            exec(command, globals, locals)
            outcome = captured.getvalue()
        except Exception as e:
            outcome = repr(e)
        finally:
            # Restore stdout even when something that is not an ``Exception``
            # (e.g. SystemExit raised by the command) propagates.
            sys.stdout = old_stdout
        queue.put(outcome)

    def run(self, command: str, timeout: Optional[int] = None) -> str:
        """Run command with own globals/locals and returns anything printed.

        Args:
            command: Python source to execute.
            timeout: seconds to wait for the command. ``None`` runs it in
                this process with no limit. Otherwise it runs in a child
                process, so namespace changes made by the command are not
                reflected on this instance.

        Returns:
            Captured stdout, ``repr`` of the raised exception, or
            ``"Execution timed out"`` if no result arrived in time.
        """
        # Imported locally so the module-level import block is untouched.
        from queue import Empty

        # Warn against dangers of PythonREPL
        warn_once()

        queue: multiprocessing.Queue = multiprocessing.Queue()

        # Only use multiprocessing if we are enforcing a timeout
        if timeout is None:
            self.worker(command, self.globals, self.locals, queue)
            return queue.get()

        wait = max(timeout, 0)
        p = multiprocessing.Process(
            target=self.worker, args=(command, self.globals, self.locals, queue)
        )
        p.start()

        # Drain the queue *before* joining: a child that has put data on a
        # queue may not exit until that data is consumed, so joining first
        # can stall and be misreported as a timeout.
        try:
            result = queue.get(timeout=wait)
        except Empty:
            p.terminate()
            p.join()
            return "Execution timed out"

        # The result is already in hand; give the child a bounded chance to
        # exit cleanly and stop it if it lingers.
        p.join(wait)
        if p.is_alive():
            p.terminate()
            p.join()
        return result
|
112
libs/experimental/tests/unit_tests/python/test_python_1.py
Normal file
112
libs/experimental/tests/unit_tests/python/test_python_1.py
Normal file
@ -0,0 +1,112 @@
|
|||||||
|
"""Test functionality of Python REPL."""
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from langchain_experimental.tools.python.tool import PythonAstREPLTool, PythonREPLTool
|
||||||
|
from langchain_experimental.utilities.python import PythonREPL
|
||||||
|
|
||||||
|
_SAMPLE_CODE = """
|
||||||
|
```
|
||||||
|
def multiply():
|
||||||
|
print(5*6)
|
||||||
|
multiply()
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
_AST_SAMPLE_CODE = """
|
||||||
|
```
|
||||||
|
def multiply():
|
||||||
|
return(5*6)
|
||||||
|
multiply()
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
_AST_SAMPLE_CODE_EXECUTE = """
|
||||||
|
```
|
||||||
|
def multiply(a, b):
|
||||||
|
return(5*6)
|
||||||
|
a = 5
|
||||||
|
b = 6
|
||||||
|
|
||||||
|
multiply(a, b)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def test_python_repl() -> None:
    """Check that state persists between runs without supplied namespaces."""
    interpreter = PythonREPL()

    # The first command defines a name in the REPL's own namespace.
    interpreter.run("foo = 1")
    namespace = interpreter.locals
    assert namespace is not None
    assert namespace["foo"] == 1

    # A follow-up command must still be able to read `foo`.
    interpreter.run("bar = foo * 2")
    namespace = interpreter.locals
    assert namespace is not None
    assert namespace["bar"] == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_python_repl_no_previous_variables() -> None:
    """Check that executed code cannot see names from the caller's scope."""
    foo = 3  # noqa: F841
    expected = """NameError("name 'foo' is not defined")"""
    result = PythonREPL().run("print(foo)")
    assert result == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_python_repl_pass_in_locals() -> None:
    """Check that names passed via ``_locals`` are visible to executed code."""
    seeded = {"foo": 4}
    interpreter = PythonREPL(_locals=seeded)
    interpreter.run("bar = foo * 2")
    namespace = interpreter.locals
    assert namespace is not None
    assert namespace["bar"] == 8
|
||||||
|
|
||||||
|
|
||||||
|
def test_functionality() -> None:
    """Check that printed output of a one-line command is returned."""
    result = PythonREPL().run("print(1 + 1)")
    assert result == "2\n"
|
||||||
|
|
||||||
|
|
||||||
|
def test_functionality_multiline() -> None:
    """Check that a fenced multi-line snippet runs through the REPL tool."""
    repl_tool = PythonREPLTool(python_repl=PythonREPL())
    result = repl_tool.run(_SAMPLE_CODE)
    assert result == "30\n"
|
||||||
|
|
||||||
|
|
||||||
|
def test_python_ast_repl_multiline() -> None:
    """Check that the AST tool returns the value of a fenced snippet."""
    if sys.version_info < (3, 9):
        pytest.skip("Python 3.9+ is required for this test")
    result = PythonAstREPLTool().run(_AST_SAMPLE_CODE)
    assert result == 30
|
||||||
|
|
||||||
|
|
||||||
|
def test_python_ast_repl_multi_statement() -> None:
    """Check that the AST tool handles several statements before the result."""
    if sys.version_info < (3, 9):
        pytest.skip("Python 3.9+ is required for this test")
    result = PythonAstREPLTool().run(_AST_SAMPLE_CODE_EXECUTE)
    assert result == 30
|
||||||
|
|
||||||
|
|
||||||
|
def test_function() -> None:
    """Check that a function defined in one run is callable in the next."""
    interpreter = PythonREPL()

    # Defining the function prints nothing.
    definition = "def add(a, b): " " return a + b"
    assert interpreter.run(definition) == ""

    # The definition must still be available to a later command.
    assert interpreter.run("print(add(1, 2))") == "3\n"
|
164
libs/experimental/tests/unit_tests/python/test_python_2.py
Normal file
164
libs/experimental/tests/unit_tests/python/test_python_2.py
Normal file
@ -0,0 +1,164 @@
|
|||||||
|
"""Test Python REPL Tools."""
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from langchain_experimental.tools.python.tool import (
|
||||||
|
PythonAstREPLTool,
|
||||||
|
PythonREPLTool,
|
||||||
|
sanitize_input,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_python_repl_tool_single_input() -> None:
    """Check that the python REPL tool accepts a single string input."""
    repl_tool = PythonREPLTool()
    assert repl_tool.is_single_input
    printed = repl_tool.run("print(1 + 1)")
    assert int(printed.strip()) == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_python_repl_print() -> None:
    """Check that output printed by a multi-line numpy program is returned."""
    source = """
import numpy as np
v1 = np.array([1, 2, 3])
v2 = np.array([4, 5, 6])
dot_product = np.dot(v1, v2)
print("The dot product is {:d}.".format(dot_product))
"""
    result = PythonREPLTool().run(source)
    assert result == "The dot product is 32.\n"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
    sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
)
def test_python_ast_repl_tool_single_input() -> None:
    """Check that the AST REPL tool accepts a single string input."""
    ast_tool = PythonAstREPLTool()
    assert ast_tool.is_single_input
    value = ast_tool.run("1 + 1")
    assert value == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
    sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
)
def test_python_ast_repl_return() -> None:
    """Check that the final expression's value is returned for both fences."""
    bare_fence = """
```
import numpy as np
v1 = np.array([1, 2, 3])
v2 = np.array([4, 5, 6])
dot_product = np.dot(v1, v2)
int(dot_product)
```
"""
    python_fence = """
```python
import numpy as np
v1 = np.array([1, 2, 3])
v2 = np.array([4, 5, 6])
dot_product = np.dot(v1, v2)
int(dot_product)
```
"""
    # The same tool instance is reused for both programs, in this order.
    ast_tool = PythonAstREPLTool()
    for source in (bare_fence, python_fence):
        assert ast_tool.run(source) == 32
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
    sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
)
def test_python_ast_repl_print() -> None:
    """Check that output printed by a trailing ``if``/``else`` is returned."""
    # The branch bodies are indented so the embedded program is valid Python.
    program = """python
string = "racecar"
if string == string[::-1]:
    print(string, "is a palindrome")
else:
    print(string, "is not a palindrome")"""
    tool = PythonAstREPLTool()
    assert tool.run(program) == "racecar is a palindrome\n"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
    sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
)
def test_repl_print_python_backticks() -> None:
    """Check that backticks inside a string literal survive sanitizing."""
    source = "`print('`python` is a great language.')`"
    result = PythonAstREPLTool().run(source)
    assert result == "`python` is a great language.\n"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
    sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
)
def test_python_ast_repl_raise_exception() -> None:
    """Check that an exception in the program is reported as a string."""
    data = {"Name": ["John", "Alice"], "Age": [30, 25]}
    source = """
import pandas as pd
df = pd.DataFrame(data)
df['Gender']
"""
    # Either outcome is accepted, so the test passes without pandas too.
    acceptable = (
        "KeyError: 'Gender'",
        "ModuleNotFoundError: No module named 'pandas'",
    )
    ast_tool = PythonAstREPLTool(locals={"data": data})
    result = ast_tool.run(source)
    assert result in acceptable
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
    sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
)
def test_python_ast_repl_one_line_print() -> None:
    """Check that a single ``print`` call yields the printed text."""
    source = 'print("The square of {} is {:.2f}".format(3, 3**2))'
    result = PythonAstREPLTool().run(source)
    assert result == "The square of 3 is 9.00\n"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
    sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
)
def test_python_ast_repl_one_line_return() -> None:
    """Check that a single expression using a supplied local is evaluated."""
    values = np.array([1, 2, 3, 4, 5])
    ast_tool = PythonAstREPLTool(locals={"arr": values})
    source = "`(arr**2).sum() # Returns sum of squares`"
    result = ast_tool.run(source)
    assert result == 55
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
    sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
)
def test_python_ast_repl_one_line_exception() -> None:
    """Check that a failing single expression is reported as a string."""
    result = PythonAstREPLTool().run("[1, 2, 3][4]")
    assert result == "IndexError: list index out of range"
|
||||||
|
|
||||||
|
|
||||||
|
def test_sanitize_input() -> None:
    """Check that fences, a ``python`` tag and outer whitespace are removed."""
    bare_fence = """
```
p = 5
```
"""
    python_fence = """
```python
p = 5
```
"""
    no_fence = """
p = 5
"""
    expected = "p = 5"
    for query in (bare_fence, python_fence, no_fence):
        actual = sanitize_input(query)
        assert expected == actual
|
Loading…
Reference in New Issue
Block a user