Compare commits

...

5 Commits

Author SHA1 Message Date
Nuno Campos
cb8ecaccc7 Dumbness 2023-10-03 13:33:54 +01:00
Nuno Campos
1586a4893d Lint 2023-10-03 13:30:37 +01:00
Nuno Campos
d402e8b214 Lint 2023-10-03 13:25:19 +01:00
Nuno Campos
5cbabbd2c1 Finish 2023-10-03 13:21:57 +01:00
Nuno Campos
869ef49699 Add safe pandas eval tool 2023-10-03 11:53:29 +01:00
3 changed files with 132 additions and 12 deletions

View File

@@ -1,5 +1,5 @@
"""Agent for working with pandas objects."""
from typing import Any, Dict, List, Optional, Sequence, Tuple
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from langchain.agents.agent import AgentExecutor, BaseSingleActionAgent
from langchain.agents.agent_toolkits.pandas.prompt import (
@@ -22,17 +22,20 @@ from langchain.schema import BasePromptTemplate
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.messages import SystemMessage
from langchain.tools import BaseTool
from langchain.tools.pandas_eval.pandas_eval import PandasEvalTool
from langchain.tools.python.tool import PythonAstREPLTool
def _get_multi_prompt(
dfs: List[Any],
llm: BaseLanguageModel,
prefix: Optional[str] = None,
suffix: Optional[str] = None,
input_variables: Optional[List[str]] = None,
include_df_in_prompt: Optional[bool] = True,
number_of_head_rows: int = 5,
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
use_sql_eval: bool = False,
) -> Tuple[BasePromptTemplate, List[Union[PythonAstREPLTool, PandasEvalTool]]]:
num_dfs = len(dfs)
if suffix is not None:
suffix_to_use = suffix
@@ -54,7 +57,11 @@ def _get_multi_prompt(
df_locals = {}
for i, dataframe in enumerate(dfs):
df_locals[f"df{i + 1}"] = dataframe
tools = [PythonAstREPLTool(locals=df_locals)]
tools: List[Union[PythonAstREPLTool, PandasEvalTool]] = (
[PythonAstREPLTool(locals=df_locals)]
if not use_sql_eval
else [PandasEvalTool(dfs=df_locals, model=llm)]
)
prompt = ZeroShotAgent.create_prompt(
tools, prefix=prefix, suffix=suffix_to_use, input_variables=input_variables
@@ -71,12 +78,14 @@ def _get_multi_prompt(
def _get_single_prompt(
df: Any,
llm: BaseLanguageModel,
prefix: Optional[str] = None,
suffix: Optional[str] = None,
input_variables: Optional[List[str]] = None,
include_df_in_prompt: Optional[bool] = True,
number_of_head_rows: int = 5,
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
use_sql_eval: bool = False,
) -> Tuple[BasePromptTemplate, List[Union[PythonAstREPLTool, PandasEvalTool]]]:
if suffix is not None:
suffix_to_use = suffix
include_df_head = True
@@ -95,7 +104,11 @@ def _get_single_prompt(
if prefix is None:
prefix = PREFIX
tools = [PythonAstREPLTool(locals={"df": df})]
tools: List[Union[PythonAstREPLTool, PandasEvalTool]] = (
[PythonAstREPLTool(locals={"df": df})]
if not use_sql_eval
else [PandasEvalTool(dfs={"df": df}, model=llm)]
)
prompt = ZeroShotAgent.create_prompt(
tools, prefix=prefix, suffix=suffix_to_use, input_variables=input_variables
@@ -111,12 +124,14 @@ def _get_single_prompt(
def _get_prompt_and_tools(
df: Any,
llm: BaseLanguageModel,
prefix: Optional[str] = None,
suffix: Optional[str] = None,
input_variables: Optional[List[str]] = None,
include_df_in_prompt: Optional[bool] = True,
number_of_head_rows: int = 5,
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
use_sql_eval: bool = False,
) -> Tuple[BasePromptTemplate, List[Union[PythonAstREPLTool, PandasEvalTool]]]:
try:
import pandas as pd
@@ -135,32 +150,38 @@ def _get_prompt_and_tools(
raise ValueError(f"Expected pandas object, got {type(df)}")
return _get_multi_prompt(
df,
llm,
prefix=prefix,
suffix=suffix,
input_variables=input_variables,
include_df_in_prompt=include_df_in_prompt,
number_of_head_rows=number_of_head_rows,
use_sql_eval=use_sql_eval,
)
else:
if not isinstance(df, pd.DataFrame):
raise ValueError(f"Expected pandas object, got {type(df)}")
return _get_single_prompt(
df,
llm,
prefix=prefix,
suffix=suffix,
input_variables=input_variables,
include_df_in_prompt=include_df_in_prompt,
number_of_head_rows=number_of_head_rows,
use_sql_eval=use_sql_eval,
)
def _get_functions_single_prompt(
df: Any,
llm: BaseLanguageModel,
prefix: Optional[str] = None,
suffix: Optional[str] = None,
include_df_in_prompt: Optional[bool] = True,
number_of_head_rows: int = 5,
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
use_sql_eval: bool = False,
) -> Tuple[BasePromptTemplate, List[Union[PythonAstREPLTool, PandasEvalTool]]]:
if suffix is not None:
suffix_to_use = suffix
if include_df_in_prompt:
@@ -177,7 +198,11 @@ def _get_functions_single_prompt(
if prefix is None:
prefix = PREFIX_FUNCTIONS
tools = [PythonAstREPLTool(locals={"df": df})]
tools: List[Union[PythonAstREPLTool, PandasEvalTool]] = (
[PythonAstREPLTool(locals={"df": df})]
if not use_sql_eval
else [PandasEvalTool(dfs={"df": df}, model=llm)]
)
system_message = SystemMessage(content=prefix + suffix_to_use)
prompt = OpenAIFunctionsAgent.create_prompt(system_message=system_message)
return prompt, tools
@@ -185,11 +210,13 @@ def _get_functions_single_prompt(
def _get_functions_multi_prompt(
dfs: Any,
llm: BaseLanguageModel,
prefix: Optional[str] = None,
suffix: Optional[str] = None,
include_df_in_prompt: Optional[bool] = True,
number_of_head_rows: int = 5,
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
use_sql_eval: bool = False,
) -> Tuple[BasePromptTemplate, List[Union[PythonAstREPLTool, PandasEvalTool]]]:
if suffix is not None:
suffix_to_use = suffix
if include_df_in_prompt:
@@ -214,7 +241,11 @@ def _get_functions_multi_prompt(
df_locals = {}
for i, dataframe in enumerate(dfs):
df_locals[f"df{i + 1}"] = dataframe
tools = [PythonAstREPLTool(locals=df_locals)]
tools: List[Union[PythonAstREPLTool, PandasEvalTool]] = (
[PythonAstREPLTool(locals=df_locals)]
if not use_sql_eval
else [PandasEvalTool(dfs=df_locals, model=llm)]
)
system_message = SystemMessage(content=prefix + suffix_to_use)
prompt = OpenAIFunctionsAgent.create_prompt(system_message=system_message)
return prompt, tools
@@ -222,12 +253,14 @@ def _get_functions_multi_prompt(
def _get_functions_prompt_and_tools(
df: Any,
llm: BaseLanguageModel,
prefix: Optional[str] = None,
suffix: Optional[str] = None,
input_variables: Optional[List[str]] = None,
include_df_in_prompt: Optional[bool] = True,
number_of_head_rows: int = 5,
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
use_sql_eval: bool = False,
) -> Tuple[BasePromptTemplate, List[Union[PythonAstREPLTool, PandasEvalTool]]]:
try:
import pandas as pd
@@ -248,20 +281,24 @@ def _get_functions_prompt_and_tools(
raise ValueError(f"Expected pandas object, got {type(df)}")
return _get_functions_multi_prompt(
df,
llm,
prefix=prefix,
suffix=suffix,
include_df_in_prompt=include_df_in_prompt,
number_of_head_rows=number_of_head_rows,
use_sql_eval=use_sql_eval,
)
else:
if not isinstance(df, pd.DataFrame):
raise ValueError(f"Expected pandas object, got {type(df)}")
return _get_functions_single_prompt(
df,
llm,
prefix=prefix,
suffix=suffix,
include_df_in_prompt=include_df_in_prompt,
number_of_head_rows=number_of_head_rows,
use_sql_eval=use_sql_eval,
)
@@ -282,18 +319,28 @@ def create_pandas_dataframe_agent(
include_df_in_prompt: Optional[bool] = True,
number_of_head_rows: int = 5,
extra_tools: Sequence[BaseTool] = (),
use_sql_eval: bool = True,
**kwargs: Dict[str, Any],
) -> AgentExecutor:
"""Construct a pandas agent from an LLM and dataframe."""
"""Construct a pandas agent from an LLM and dataframe.
Args:
use_sql_eval: Whether to evaluate pandas code using SQL translation.
Unlike the default Python REPL, this doesn't execute
arbitrary Python code, but requires the `duckdb` package.
When `False`, it uses Python REPL tool.
"""
agent: BaseSingleActionAgent
if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION:
prompt, base_tools = _get_prompt_and_tools(
df,
llm,
prefix=prefix,
suffix=suffix,
input_variables=input_variables,
include_df_in_prompt=include_df_in_prompt,
number_of_head_rows=number_of_head_rows,
use_sql_eval=use_sql_eval,
)
tools = base_tools + list(extra_tools)
llm_chain = LLMChain(
@@ -311,11 +358,13 @@ def create_pandas_dataframe_agent(
elif agent_type == AgentType.OPENAI_FUNCTIONS:
_prompt, base_tools = _get_functions_prompt_and_tools(
df,
llm,
prefix=prefix,
suffix=suffix,
input_variables=input_variables,
include_df_in_prompt=include_df_in_prompt,
number_of_head_rows=number_of_head_rows,
use_sql_eval=use_sql_eval,
)
tools = base_tools + list(extra_tools)
agent = OpenAIFunctionsAgent(

View File

@@ -0,0 +1,71 @@
from functools import partial
from typing import Any, Dict, Optional
from typing_extensions import TypeAlias
from langchain.callbacks.manager import CallbackManagerForToolRun
from langchain.prompts.prompt import PromptTemplate
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable.base import Runnable
from langchain.tools.base import BaseTool
DF: TypeAlias = Any
def evaluate_sql_on_dfs(sql: str, **dfs: DF) -> DF:
"""Evaluate a SQL query on a pandas dataframe."""
try:
import duckdb
except ImportError:
raise ImportError(
"duckdb is required to evaluate SQL queries on pandas dataframes."
)
if not sql:
return ""
locals().update(dfs)
conn = duckdb.connect()
return conn.execute(sql).fetchall()
def get_pandas_eval_chain(model: BaseLanguageModel, dfs: Dict[str, DF]) -> Runnable:
prompt = PromptTemplate.from_template(
"""You are an expert data scientist, tasked with converting python code manipulating pandas dataframes into SQL queries.
You should write a SQL query that will return the same result as the python code below/
There are SQL tables with the same name as any Pandas dataframe in the code ({tables}).
You are given the following python code:
{input}
If the python code is not valid pandas code, you should return an empty string.
SQL query:""", # noqa: E501
partial_variables={"tables": str(list(dfs.keys()))},
)
return prompt | model | StrOutputParser() | partial(evaluate_sql_on_dfs, **dfs)
class PandasEvalTool(BaseTool):
name: str = "pandas_eval"
description: str = "Evaluate pandas code against one or more dataframes."
dfs: Dict[str, DF]
model: BaseLanguageModel
def _run(
self,
query: str,
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> Any:
chain = get_pandas_eval_chain(self.model, self.dfs)
return chain.invoke(
{"input": query},
{"callbacks": run_manager.get_child()} if run_manager else {},
)