This commit is contained in:
Nuno Campos
2023-10-03 13:21:57 +01:00
parent 869ef49699
commit 5cbabbd2c1
2 changed files with 52 additions and 36 deletions

View File

@@ -1,5 +1,5 @@
"""Agent for working with pandas objects."""
from typing import Any, Dict, List, Optional, Sequence, Tuple
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from langchain.agents.agent import AgentExecutor, BaseSingleActionAgent
from langchain.agents.agent_toolkits.pandas.prompt import (
@@ -34,8 +34,8 @@ def _get_multi_prompt(
input_variables: Optional[List[str]] = None,
include_df_in_prompt: Optional[bool] = True,
number_of_head_rows: int = 5,
use_pandas_eval_tool: bool = False,
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
use_sql_eval: bool = False,
) -> Tuple[BasePromptTemplate, List[Union[PythonAstREPLTool, PandasEvalTool]]]:
num_dfs = len(dfs)
if suffix is not None:
suffix_to_use = suffix
@@ -57,9 +57,9 @@ def _get_multi_prompt(
df_locals = {}
for i, dataframe in enumerate(dfs):
df_locals[f"df{i + 1}"] = dataframe
tools = (
tools: List[Union[PythonAstREPLTool, PandasEvalTool]] = (
[PythonAstREPLTool(locals=df_locals)]
if not use_pandas_eval_tool
if not use_sql_eval
else [PandasEvalTool(dfs=df_locals, model=llm)]
)
@@ -84,8 +84,8 @@ def _get_single_prompt(
input_variables: Optional[List[str]] = None,
include_df_in_prompt: Optional[bool] = True,
number_of_head_rows: int = 5,
use_pandas_eval_tool: bool = False,
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
use_sql_eval: bool = False,
) -> Tuple[BasePromptTemplate, List[Union[PythonAstREPLTool, PandasEvalTool]]]:
if suffix is not None:
suffix_to_use = suffix
include_df_head = True
@@ -104,9 +104,9 @@ def _get_single_prompt(
if prefix is None:
prefix = PREFIX
tools = (
tools: List[Union[PythonAstREPLTool, PandasEvalTool]] = (
[PythonAstREPLTool(locals={"df": df})]
if not use_pandas_eval_tool
if not use_sql_eval
else [PandasEvalTool(dfs={"df": df}, model=llm)]
)
@@ -130,8 +130,8 @@ def _get_prompt_and_tools(
input_variables: Optional[List[str]] = None,
include_df_in_prompt: Optional[bool] = True,
number_of_head_rows: int = 5,
use_pandas_eval_tool: bool = False,
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
use_sql_eval: bool = False,
) -> Tuple[BasePromptTemplate, List[Union[PythonAstREPLTool, PandasEvalTool]]]:
try:
import pandas as pd
@@ -156,7 +156,7 @@ def _get_prompt_and_tools(
input_variables=input_variables,
include_df_in_prompt=include_df_in_prompt,
number_of_head_rows=number_of_head_rows,
use_pandas_eval_tool=use_pandas_eval_tool,
use_sql_eval=use_sql_eval,
)
else:
if not isinstance(df, pd.DataFrame):
@@ -169,7 +169,7 @@ def _get_prompt_and_tools(
input_variables=input_variables,
include_df_in_prompt=include_df_in_prompt,
number_of_head_rows=number_of_head_rows,
use_pandas_eval_tool=use_pandas_eval_tool,
use_sql_eval=use_sql_eval,
)
@@ -180,8 +180,8 @@ def _get_functions_single_prompt(
suffix: Optional[str] = None,
include_df_in_prompt: Optional[bool] = True,
number_of_head_rows: int = 5,
use_pandas_eval_tool: bool = False,
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
use_sql_eval: bool = False,
) -> Tuple[BasePromptTemplate, List[Union[PythonAstREPLTool, PandasEvalTool]]]:
if suffix is not None:
suffix_to_use = suffix
if include_df_in_prompt:
@@ -198,9 +198,9 @@ def _get_functions_single_prompt(
if prefix is None:
prefix = PREFIX_FUNCTIONS
tools = (
tools: List[Union[PythonAstREPLTool, PandasEvalTool]] = (
[PythonAstREPLTool(locals={"df": df})]
if not use_pandas_eval_tool
if not use_sql_eval
else [PandasEvalTool(dfs={"df": df}, model=llm)]
)
system_message = SystemMessage(content=prefix + suffix_to_use)
@@ -215,8 +215,8 @@ def _get_functions_multi_prompt(
suffix: Optional[str] = None,
include_df_in_prompt: Optional[bool] = True,
number_of_head_rows: int = 5,
use_pandas_eval_tool: bool = False,
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
use_sql_eval: bool = False,
) -> Tuple[BasePromptTemplate, List[Union[PythonAstREPLTool, PandasEvalTool]]]:
if suffix is not None:
suffix_to_use = suffix
if include_df_in_prompt:
@@ -241,9 +241,9 @@ def _get_functions_multi_prompt(
df_locals = {}
for i, dataframe in enumerate(dfs):
df_locals[f"df{i + 1}"] = dataframe
tools = (
tools: List[Union[PythonAstREPLTool, PandasEvalTool]] = (
[PythonAstREPLTool(locals=df_locals)]
if not use_pandas_eval_tool
if not use_sql_eval
else [PandasEvalTool(dfs=df_locals, model=llm)]
)
system_message = SystemMessage(content=prefix + suffix_to_use)
@@ -259,8 +259,8 @@ def _get_functions_prompt_and_tools(
input_variables: Optional[List[str]] = None,
include_df_in_prompt: Optional[bool] = True,
number_of_head_rows: int = 5,
use_pandas_eval_tool: bool = False,
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
use_sql_eval: bool = False,
) -> Tuple[BasePromptTemplate, List[Union[PythonAstREPLTool, PandasEvalTool]]]:
try:
import pandas as pd
@@ -286,7 +286,7 @@ def _get_functions_prompt_and_tools(
suffix=suffix,
include_df_in_prompt=include_df_in_prompt,
number_of_head_rows=number_of_head_rows,
use_pandas_eval_tool=use_pandas_eval_tool,
use_sql_eval=use_sql_eval,
)
else:
if not isinstance(df, pd.DataFrame):
@@ -298,7 +298,7 @@ def _get_functions_prompt_and_tools(
suffix=suffix,
include_df_in_prompt=include_df_in_prompt,
number_of_head_rows=number_of_head_rows,
use_pandas_eval_tool=use_pandas_eval_tool,
use_sql_eval=use_sql_eval,
)
@@ -319,10 +319,17 @@ def create_pandas_dataframe_agent(
include_df_in_prompt: Optional[bool] = True,
number_of_head_rows: int = 5,
extra_tools: Sequence[BaseTool] = (),
use_pandas_eval_tool: bool = True,
use_sql_eval: bool = True,
**kwargs: Dict[str, Any],
) -> AgentExecutor:
"""Construct a pandas agent from an LLM and dataframe."""
"""Construct a pandas agent from an LLM and dataframe.
Args:
use_sql_eval: Whether to evaluate pandas code by translating it to SQL.
Unlike the Python REPL tool, this doesn't execute arbitrary
Python code, but it requires the `duckdb` package.
Defaults to `True`; when `False`, the Python REPL tool is used.
"""
agent: BaseSingleActionAgent
if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION:
prompt, base_tools = _get_prompt_and_tools(
@@ -333,7 +340,7 @@ def create_pandas_dataframe_agent(
input_variables=input_variables,
include_df_in_prompt=include_df_in_prompt,
number_of_head_rows=number_of_head_rows,
use_pandas_eval_tool=use_pandas_eval_tool,
use_sql_eval=use_sql_eval,
)
tools = base_tools + list(extra_tools)
llm_chain = LLMChain(
@@ -351,12 +358,13 @@ def create_pandas_dataframe_agent(
elif agent_type == AgentType.OPENAI_FUNCTIONS:
_prompt, base_tools = _get_functions_prompt_and_tools(
df,
llm,
prefix=prefix,
suffix=suffix,
input_variables=input_variables,
include_df_in_prompt=include_df_in_prompt,
number_of_head_rows=number_of_head_rows,
use_pandas_eval_tool=use_pandas_eval_tool,
use_sql_eval=use_sql_eval,
)
tools = base_tools + list(extra_tools)
agent = OpenAIFunctionsAgent(

View File

@@ -1,12 +1,13 @@
from functools import partial
from typing import Any, Optional
from typing_extensions import TypeAlias
from langchain.callbacks.manager import CallbackManagerForToolRun
from typing import Any
from typing_extensions import TypeAlias
from langchain.callbacks.manager import CallbackManagerForToolRun
from langchain.prompts.prompt import PromptTemplate
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.runnable.base import Runnable
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable.base import Runnable
from langchain.tools.base import BaseTool
DF: TypeAlias = Any
@@ -21,6 +22,9 @@ def evaluate_sql_on_dfs(sql: str, **dfs: DF) -> DF:
"duckdb is required to evaluate SQL queries on pandas dataframes."
)
if not sql:
return ""
locals().update(dfs)
conn = duckdb.connect()
return conn.execute(sql).fetchall()
@@ -30,13 +34,17 @@ def get_pandas_eval_chain(model: BaseLanguageModel, dfs: dict[str, DF]) -> Runna
prompt = PromptTemplate.from_template(
"""You are an expert data scientist, tasked with converting python code manipulating pandas dataframes into SQL queries.
You should write a SQL query that will return the same result as the python code below/There are SQL tables with the same name as any Pandas dataframe in the code.
You should write a SQL query that will return the same result as the python code below/
There are SQL tables with the same name as any Pandas dataframe in the code ({tables}).
You are given the following python code:
{input}
SQL query:"""
If the python code is not valid pandas code, you should return an empty string.
SQL query:""", # noqa: E501
partial_variables={"tables": str(list(dfs.keys()))},
)
return prompt | model | StrOutputParser() | partial(evaluate_sql_on_dfs, **dfs)
@@ -54,7 +62,7 @@ class PandasEvalTool(BaseTool):
def _run(
self,
query: str,
run_manager: Optional[CallbackManagerForToolRun] = None,
run_manager: CallbackManagerForToolRun,
) -> Any:
chain = get_pandas_eval_chain(self.model, self.dfs)
return chain.invoke({"input": query}, {"callbacks": run_manager.get_child()})