mirror of
https://github.com/hwchase17/langchain.git
synced 2026-01-21 21:56:38 +00:00
Compare commits
5 Commits
replace_ap
...
nc/pandas-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cb8ecaccc7 | ||
|
|
1586a4893d | ||
|
|
d402e8b214 | ||
|
|
5cbabbd2c1 | ||
|
|
869ef49699 |
@@ -1,5 +1,5 @@
|
||||
"""Agent for working with pandas objects."""
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
from langchain.agents.agent import AgentExecutor, BaseSingleActionAgent
|
||||
from langchain.agents.agent_toolkits.pandas.prompt import (
|
||||
@@ -22,17 +22,20 @@ from langchain.schema import BasePromptTemplate
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.schema.messages import SystemMessage
|
||||
from langchain.tools import BaseTool
|
||||
from langchain.tools.pandas_eval.pandas_eval import PandasEvalTool
|
||||
from langchain.tools.python.tool import PythonAstREPLTool
|
||||
|
||||
|
||||
def _get_multi_prompt(
|
||||
dfs: List[Any],
|
||||
llm: BaseLanguageModel,
|
||||
prefix: Optional[str] = None,
|
||||
suffix: Optional[str] = None,
|
||||
input_variables: Optional[List[str]] = None,
|
||||
include_df_in_prompt: Optional[bool] = True,
|
||||
number_of_head_rows: int = 5,
|
||||
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
|
||||
use_sql_eval: bool = False,
|
||||
) -> Tuple[BasePromptTemplate, List[Union[PythonAstREPLTool, PandasEvalTool]]]:
|
||||
num_dfs = len(dfs)
|
||||
if suffix is not None:
|
||||
suffix_to_use = suffix
|
||||
@@ -54,7 +57,11 @@ def _get_multi_prompt(
|
||||
df_locals = {}
|
||||
for i, dataframe in enumerate(dfs):
|
||||
df_locals[f"df{i + 1}"] = dataframe
|
||||
tools = [PythonAstREPLTool(locals=df_locals)]
|
||||
tools: List[Union[PythonAstREPLTool, PandasEvalTool]] = (
|
||||
[PythonAstREPLTool(locals=df_locals)]
|
||||
if not use_sql_eval
|
||||
else [PandasEvalTool(dfs=df_locals, model=llm)]
|
||||
)
|
||||
|
||||
prompt = ZeroShotAgent.create_prompt(
|
||||
tools, prefix=prefix, suffix=suffix_to_use, input_variables=input_variables
|
||||
@@ -71,12 +78,14 @@ def _get_multi_prompt(
|
||||
|
||||
def _get_single_prompt(
|
||||
df: Any,
|
||||
llm: BaseLanguageModel,
|
||||
prefix: Optional[str] = None,
|
||||
suffix: Optional[str] = None,
|
||||
input_variables: Optional[List[str]] = None,
|
||||
include_df_in_prompt: Optional[bool] = True,
|
||||
number_of_head_rows: int = 5,
|
||||
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
|
||||
use_sql_eval: bool = False,
|
||||
) -> Tuple[BasePromptTemplate, List[Union[PythonAstREPLTool, PandasEvalTool]]]:
|
||||
if suffix is not None:
|
||||
suffix_to_use = suffix
|
||||
include_df_head = True
|
||||
@@ -95,7 +104,11 @@ def _get_single_prompt(
|
||||
if prefix is None:
|
||||
prefix = PREFIX
|
||||
|
||||
tools = [PythonAstREPLTool(locals={"df": df})]
|
||||
tools: List[Union[PythonAstREPLTool, PandasEvalTool]] = (
|
||||
[PythonAstREPLTool(locals={"df": df})]
|
||||
if not use_sql_eval
|
||||
else [PandasEvalTool(dfs={"df": df}, model=llm)]
|
||||
)
|
||||
|
||||
prompt = ZeroShotAgent.create_prompt(
|
||||
tools, prefix=prefix, suffix=suffix_to_use, input_variables=input_variables
|
||||
@@ -111,12 +124,14 @@ def _get_single_prompt(
|
||||
|
||||
def _get_prompt_and_tools(
|
||||
df: Any,
|
||||
llm: BaseLanguageModel,
|
||||
prefix: Optional[str] = None,
|
||||
suffix: Optional[str] = None,
|
||||
input_variables: Optional[List[str]] = None,
|
||||
include_df_in_prompt: Optional[bool] = True,
|
||||
number_of_head_rows: int = 5,
|
||||
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
|
||||
use_sql_eval: bool = False,
|
||||
) -> Tuple[BasePromptTemplate, List[Union[PythonAstREPLTool, PandasEvalTool]]]:
|
||||
try:
|
||||
import pandas as pd
|
||||
|
||||
@@ -135,32 +150,38 @@ def _get_prompt_and_tools(
|
||||
raise ValueError(f"Expected pandas object, got {type(df)}")
|
||||
return _get_multi_prompt(
|
||||
df,
|
||||
llm,
|
||||
prefix=prefix,
|
||||
suffix=suffix,
|
||||
input_variables=input_variables,
|
||||
include_df_in_prompt=include_df_in_prompt,
|
||||
number_of_head_rows=number_of_head_rows,
|
||||
use_sql_eval=use_sql_eval,
|
||||
)
|
||||
else:
|
||||
if not isinstance(df, pd.DataFrame):
|
||||
raise ValueError(f"Expected pandas object, got {type(df)}")
|
||||
return _get_single_prompt(
|
||||
df,
|
||||
llm,
|
||||
prefix=prefix,
|
||||
suffix=suffix,
|
||||
input_variables=input_variables,
|
||||
include_df_in_prompt=include_df_in_prompt,
|
||||
number_of_head_rows=number_of_head_rows,
|
||||
use_sql_eval=use_sql_eval,
|
||||
)
|
||||
|
||||
|
||||
def _get_functions_single_prompt(
|
||||
df: Any,
|
||||
llm: BaseLanguageModel,
|
||||
prefix: Optional[str] = None,
|
||||
suffix: Optional[str] = None,
|
||||
include_df_in_prompt: Optional[bool] = True,
|
||||
number_of_head_rows: int = 5,
|
||||
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
|
||||
use_sql_eval: bool = False,
|
||||
) -> Tuple[BasePromptTemplate, List[Union[PythonAstREPLTool, PandasEvalTool]]]:
|
||||
if suffix is not None:
|
||||
suffix_to_use = suffix
|
||||
if include_df_in_prompt:
|
||||
@@ -177,7 +198,11 @@ def _get_functions_single_prompt(
|
||||
if prefix is None:
|
||||
prefix = PREFIX_FUNCTIONS
|
||||
|
||||
tools = [PythonAstREPLTool(locals={"df": df})]
|
||||
tools: List[Union[PythonAstREPLTool, PandasEvalTool]] = (
|
||||
[PythonAstREPLTool(locals={"df": df})]
|
||||
if not use_sql_eval
|
||||
else [PandasEvalTool(dfs={"df": df}, model=llm)]
|
||||
)
|
||||
system_message = SystemMessage(content=prefix + suffix_to_use)
|
||||
prompt = OpenAIFunctionsAgent.create_prompt(system_message=system_message)
|
||||
return prompt, tools
|
||||
@@ -185,11 +210,13 @@ def _get_functions_single_prompt(
|
||||
|
||||
def _get_functions_multi_prompt(
|
||||
dfs: Any,
|
||||
llm: BaseLanguageModel,
|
||||
prefix: Optional[str] = None,
|
||||
suffix: Optional[str] = None,
|
||||
include_df_in_prompt: Optional[bool] = True,
|
||||
number_of_head_rows: int = 5,
|
||||
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
|
||||
use_sql_eval: bool = False,
|
||||
) -> Tuple[BasePromptTemplate, List[Union[PythonAstREPLTool, PandasEvalTool]]]:
|
||||
if suffix is not None:
|
||||
suffix_to_use = suffix
|
||||
if include_df_in_prompt:
|
||||
@@ -214,7 +241,11 @@ def _get_functions_multi_prompt(
|
||||
df_locals = {}
|
||||
for i, dataframe in enumerate(dfs):
|
||||
df_locals[f"df{i + 1}"] = dataframe
|
||||
tools = [PythonAstREPLTool(locals=df_locals)]
|
||||
tools: List[Union[PythonAstREPLTool, PandasEvalTool]] = (
|
||||
[PythonAstREPLTool(locals=df_locals)]
|
||||
if not use_sql_eval
|
||||
else [PandasEvalTool(dfs=df_locals, model=llm)]
|
||||
)
|
||||
system_message = SystemMessage(content=prefix + suffix_to_use)
|
||||
prompt = OpenAIFunctionsAgent.create_prompt(system_message=system_message)
|
||||
return prompt, tools
|
||||
@@ -222,12 +253,14 @@ def _get_functions_multi_prompt(
|
||||
|
||||
def _get_functions_prompt_and_tools(
|
||||
df: Any,
|
||||
llm: BaseLanguageModel,
|
||||
prefix: Optional[str] = None,
|
||||
suffix: Optional[str] = None,
|
||||
input_variables: Optional[List[str]] = None,
|
||||
include_df_in_prompt: Optional[bool] = True,
|
||||
number_of_head_rows: int = 5,
|
||||
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
|
||||
use_sql_eval: bool = False,
|
||||
) -> Tuple[BasePromptTemplate, List[Union[PythonAstREPLTool, PandasEvalTool]]]:
|
||||
try:
|
||||
import pandas as pd
|
||||
|
||||
@@ -248,20 +281,24 @@ def _get_functions_prompt_and_tools(
|
||||
raise ValueError(f"Expected pandas object, got {type(df)}")
|
||||
return _get_functions_multi_prompt(
|
||||
df,
|
||||
llm,
|
||||
prefix=prefix,
|
||||
suffix=suffix,
|
||||
include_df_in_prompt=include_df_in_prompt,
|
||||
number_of_head_rows=number_of_head_rows,
|
||||
use_sql_eval=use_sql_eval,
|
||||
)
|
||||
else:
|
||||
if not isinstance(df, pd.DataFrame):
|
||||
raise ValueError(f"Expected pandas object, got {type(df)}")
|
||||
return _get_functions_single_prompt(
|
||||
df,
|
||||
llm,
|
||||
prefix=prefix,
|
||||
suffix=suffix,
|
||||
include_df_in_prompt=include_df_in_prompt,
|
||||
number_of_head_rows=number_of_head_rows,
|
||||
use_sql_eval=use_sql_eval,
|
||||
)
|
||||
|
||||
|
||||
@@ -282,18 +319,28 @@ def create_pandas_dataframe_agent(
|
||||
include_df_in_prompt: Optional[bool] = True,
|
||||
number_of_head_rows: int = 5,
|
||||
extra_tools: Sequence[BaseTool] = (),
|
||||
use_sql_eval: bool = True,
|
||||
**kwargs: Dict[str, Any],
|
||||
) -> AgentExecutor:
|
||||
"""Construct a pandas agent from an LLM and dataframe."""
|
||||
"""Construct a pandas agent from an LLM and dataframe.
|
||||
|
||||
Args:
|
||||
use_sql_eval: Whether to evaluate pandas code using SQL translation.
|
||||
Unlike the default Python REPL, this doesn't execute
|
||||
arbitrary Python code, but requires the `duckdb` package.
|
||||
When `False`, it uses Python REPL tool.
|
||||
"""
|
||||
agent: BaseSingleActionAgent
|
||||
if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION:
|
||||
prompt, base_tools = _get_prompt_and_tools(
|
||||
df,
|
||||
llm,
|
||||
prefix=prefix,
|
||||
suffix=suffix,
|
||||
input_variables=input_variables,
|
||||
include_df_in_prompt=include_df_in_prompt,
|
||||
number_of_head_rows=number_of_head_rows,
|
||||
use_sql_eval=use_sql_eval,
|
||||
)
|
||||
tools = base_tools + list(extra_tools)
|
||||
llm_chain = LLMChain(
|
||||
@@ -311,11 +358,13 @@ def create_pandas_dataframe_agent(
|
||||
elif agent_type == AgentType.OPENAI_FUNCTIONS:
|
||||
_prompt, base_tools = _get_functions_prompt_and_tools(
|
||||
df,
|
||||
llm,
|
||||
prefix=prefix,
|
||||
suffix=suffix,
|
||||
input_variables=input_variables,
|
||||
include_df_in_prompt=include_df_in_prompt,
|
||||
number_of_head_rows=number_of_head_rows,
|
||||
use_sql_eval=use_sql_eval,
|
||||
)
|
||||
tools = base_tools + list(extra_tools)
|
||||
agent = OpenAIFunctionsAgent(
|
||||
|
||||
71
libs/langchain/langchain/tools/pandas_eval/pandas_eval.py
Normal file
71
libs/langchain/langchain/tools/pandas_eval/pandas_eval.py
Normal file
@@ -0,0 +1,71 @@
|
||||
from functools import partial
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForToolRun
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.schema.output_parser import StrOutputParser
|
||||
from langchain.schema.runnable.base import Runnable
|
||||
from langchain.tools.base import BaseTool
|
||||
|
||||
DF: TypeAlias = Any
|
||||
|
||||
|
||||
def evaluate_sql_on_dfs(sql: str, **dfs: DF) -> DF:
|
||||
"""Evaluate a SQL query on a pandas dataframe."""
|
||||
try:
|
||||
import duckdb
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"duckdb is required to evaluate SQL queries on pandas dataframes."
|
||||
)
|
||||
|
||||
if not sql:
|
||||
return ""
|
||||
|
||||
locals().update(dfs)
|
||||
conn = duckdb.connect()
|
||||
return conn.execute(sql).fetchall()
|
||||
|
||||
|
||||
def get_pandas_eval_chain(model: BaseLanguageModel, dfs: Dict[str, DF]) -> Runnable:
|
||||
prompt = PromptTemplate.from_template(
|
||||
"""You are an expert data scientist, tasked with converting python code manipulating pandas dataframes into SQL queries.
|
||||
|
||||
You should write a SQL query that will return the same result as the python code below/
|
||||
There are SQL tables with the same name as any Pandas dataframe in the code ({tables}).
|
||||
|
||||
You are given the following python code:
|
||||
|
||||
{input}
|
||||
|
||||
If the python code is not valid pandas code, you should return an empty string.
|
||||
|
||||
SQL query:""", # noqa: E501
|
||||
partial_variables={"tables": str(list(dfs.keys()))},
|
||||
)
|
||||
|
||||
return prompt | model | StrOutputParser() | partial(evaluate_sql_on_dfs, **dfs)
|
||||
|
||||
|
||||
class PandasEvalTool(BaseTool):
|
||||
name: str = "pandas_eval"
|
||||
|
||||
description: str = "Evaluate pandas code against one or more dataframes."
|
||||
|
||||
dfs: Dict[str, DF]
|
||||
|
||||
model: BaseLanguageModel
|
||||
|
||||
def _run(
|
||||
self,
|
||||
query: str,
|
||||
run_manager: Optional[CallbackManagerForToolRun] = None,
|
||||
) -> Any:
|
||||
chain = get_pandas_eval_chain(self.model, self.dfs)
|
||||
return chain.invoke(
|
||||
{"input": query},
|
||||
{"callbacks": run_manager.get_child()} if run_manager else {},
|
||||
)
|
||||
Reference in New Issue
Block a user