mirror of
https://github.com/hwchase17/langchain.git
synced 2026-04-24 04:36:46 +00:00
Finish
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
"""Agent for working with pandas objects."""
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
from langchain.agents.agent import AgentExecutor, BaseSingleActionAgent
|
||||
from langchain.agents.agent_toolkits.pandas.prompt import (
|
||||
@@ -34,8 +34,8 @@ def _get_multi_prompt(
|
||||
input_variables: Optional[List[str]] = None,
|
||||
include_df_in_prompt: Optional[bool] = True,
|
||||
number_of_head_rows: int = 5,
|
||||
use_pandas_eval_tool: bool = False,
|
||||
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
|
||||
use_sql_eval: bool = False,
|
||||
) -> Tuple[BasePromptTemplate, List[Union[PythonAstREPLTool, PandasEvalTool]]]:
|
||||
num_dfs = len(dfs)
|
||||
if suffix is not None:
|
||||
suffix_to_use = suffix
|
||||
@@ -57,9 +57,9 @@ def _get_multi_prompt(
|
||||
df_locals = {}
|
||||
for i, dataframe in enumerate(dfs):
|
||||
df_locals[f"df{i + 1}"] = dataframe
|
||||
tools = (
|
||||
tools: List[Union[PythonAstREPLTool, PandasEvalTool]] = (
|
||||
[PythonAstREPLTool(locals=df_locals)]
|
||||
if not use_pandas_eval_tool
|
||||
if not use_sql_eval
|
||||
else [PandasEvalTool(dfs=df_locals, model=llm)]
|
||||
)
|
||||
|
||||
@@ -84,8 +84,8 @@ def _get_single_prompt(
|
||||
input_variables: Optional[List[str]] = None,
|
||||
include_df_in_prompt: Optional[bool] = True,
|
||||
number_of_head_rows: int = 5,
|
||||
use_pandas_eval_tool: bool = False,
|
||||
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
|
||||
use_sql_eval: bool = False,
|
||||
) -> Tuple[BasePromptTemplate, List[Union[PythonAstREPLTool, PandasEvalTool]]]:
|
||||
if suffix is not None:
|
||||
suffix_to_use = suffix
|
||||
include_df_head = True
|
||||
@@ -104,9 +104,9 @@ def _get_single_prompt(
|
||||
if prefix is None:
|
||||
prefix = PREFIX
|
||||
|
||||
tools = (
|
||||
tools: List[Union[PythonAstREPLTool, PandasEvalTool]] = (
|
||||
[PythonAstREPLTool(locals={"df": df})]
|
||||
if not use_pandas_eval_tool
|
||||
if not use_sql_eval
|
||||
else [PandasEvalTool(dfs={"df": df}, model=llm)]
|
||||
)
|
||||
|
||||
@@ -130,8 +130,8 @@ def _get_prompt_and_tools(
|
||||
input_variables: Optional[List[str]] = None,
|
||||
include_df_in_prompt: Optional[bool] = True,
|
||||
number_of_head_rows: int = 5,
|
||||
use_pandas_eval_tool: bool = False,
|
||||
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
|
||||
use_sql_eval: bool = False,
|
||||
) -> Tuple[BasePromptTemplate, List[Union[PythonAstREPLTool, PandasEvalTool]]]:
|
||||
try:
|
||||
import pandas as pd
|
||||
|
||||
@@ -156,7 +156,7 @@ def _get_prompt_and_tools(
|
||||
input_variables=input_variables,
|
||||
include_df_in_prompt=include_df_in_prompt,
|
||||
number_of_head_rows=number_of_head_rows,
|
||||
use_pandas_eval_tool=use_pandas_eval_tool,
|
||||
use_sql_eval=use_sql_eval,
|
||||
)
|
||||
else:
|
||||
if not isinstance(df, pd.DataFrame):
|
||||
@@ -169,7 +169,7 @@ def _get_prompt_and_tools(
|
||||
input_variables=input_variables,
|
||||
include_df_in_prompt=include_df_in_prompt,
|
||||
number_of_head_rows=number_of_head_rows,
|
||||
use_pandas_eval_tool=use_pandas_eval_tool,
|
||||
use_sql_eval=use_sql_eval,
|
||||
)
|
||||
|
||||
|
||||
@@ -180,8 +180,8 @@ def _get_functions_single_prompt(
|
||||
suffix: Optional[str] = None,
|
||||
include_df_in_prompt: Optional[bool] = True,
|
||||
number_of_head_rows: int = 5,
|
||||
use_pandas_eval_tool: bool = False,
|
||||
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
|
||||
use_sql_eval: bool = False,
|
||||
) -> Tuple[BasePromptTemplate, List[Union[PythonAstREPLTool, PandasEvalTool]]]:
|
||||
if suffix is not None:
|
||||
suffix_to_use = suffix
|
||||
if include_df_in_prompt:
|
||||
@@ -198,9 +198,9 @@ def _get_functions_single_prompt(
|
||||
if prefix is None:
|
||||
prefix = PREFIX_FUNCTIONS
|
||||
|
||||
tools = (
|
||||
tools: List[Union[PythonAstREPLTool, PandasEvalTool]] = (
|
||||
[PythonAstREPLTool(locals={"df": df})]
|
||||
if not use_pandas_eval_tool
|
||||
if not use_sql_eval
|
||||
else [PandasEvalTool(dfs={"df": df}, model=llm)]
|
||||
)
|
||||
system_message = SystemMessage(content=prefix + suffix_to_use)
|
||||
@@ -215,8 +215,8 @@ def _get_functions_multi_prompt(
|
||||
suffix: Optional[str] = None,
|
||||
include_df_in_prompt: Optional[bool] = True,
|
||||
number_of_head_rows: int = 5,
|
||||
use_pandas_eval_tool: bool = False,
|
||||
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
|
||||
use_sql_eval: bool = False,
|
||||
) -> Tuple[BasePromptTemplate, List[Union[PythonAstREPLTool, PandasEvalTool]]]:
|
||||
if suffix is not None:
|
||||
suffix_to_use = suffix
|
||||
if include_df_in_prompt:
|
||||
@@ -241,9 +241,9 @@ def _get_functions_multi_prompt(
|
||||
df_locals = {}
|
||||
for i, dataframe in enumerate(dfs):
|
||||
df_locals[f"df{i + 1}"] = dataframe
|
||||
tools = (
|
||||
tools: List[Union[PythonAstREPLTool, PandasEvalTool]] = (
|
||||
[PythonAstREPLTool(locals=df_locals)]
|
||||
if not use_pandas_eval_tool
|
||||
if not use_sql_eval
|
||||
else [PandasEvalTool(dfs=df_locals, model=llm)]
|
||||
)
|
||||
system_message = SystemMessage(content=prefix + suffix_to_use)
|
||||
@@ -259,8 +259,8 @@ def _get_functions_prompt_and_tools(
|
||||
input_variables: Optional[List[str]] = None,
|
||||
include_df_in_prompt: Optional[bool] = True,
|
||||
number_of_head_rows: int = 5,
|
||||
use_pandas_eval_tool: bool = False,
|
||||
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
|
||||
use_sql_eval: bool = False,
|
||||
) -> Tuple[BasePromptTemplate, List[Union[PythonAstREPLTool, PandasEvalTool]]]:
|
||||
try:
|
||||
import pandas as pd
|
||||
|
||||
@@ -286,7 +286,7 @@ def _get_functions_prompt_and_tools(
|
||||
suffix=suffix,
|
||||
include_df_in_prompt=include_df_in_prompt,
|
||||
number_of_head_rows=number_of_head_rows,
|
||||
use_pandas_eval_tool=use_pandas_eval_tool,
|
||||
use_sql_eval=use_sql_eval,
|
||||
)
|
||||
else:
|
||||
if not isinstance(df, pd.DataFrame):
|
||||
@@ -298,7 +298,7 @@ def _get_functions_prompt_and_tools(
|
||||
suffix=suffix,
|
||||
include_df_in_prompt=include_df_in_prompt,
|
||||
number_of_head_rows=number_of_head_rows,
|
||||
use_pandas_eval_tool=use_pandas_eval_tool,
|
||||
use_sql_eval=use_sql_eval,
|
||||
)
|
||||
|
||||
|
||||
@@ -319,10 +319,17 @@ def create_pandas_dataframe_agent(
|
||||
include_df_in_prompt: Optional[bool] = True,
|
||||
number_of_head_rows: int = 5,
|
||||
extra_tools: Sequence[BaseTool] = (),
|
||||
use_pandas_eval_tool: bool = True,
|
||||
use_sql_eval: bool = True,
|
||||
**kwargs: Dict[str, Any],
|
||||
) -> AgentExecutor:
|
||||
"""Construct a pandas agent from an LLM and dataframe."""
|
||||
"""Construct a pandas agent from an LLM and dataframe.
|
||||
|
||||
Args:
|
||||
use_sql_eval: Whether to evaluate pandas code using SQL translation.
|
||||
Unlike the default Python REPL, this doesn't execute
|
||||
arbitrary Python code, but requires the `duckdb` package.
|
||||
When `False`, it uses Python REPL tool.
|
||||
"""
|
||||
agent: BaseSingleActionAgent
|
||||
if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION:
|
||||
prompt, base_tools = _get_prompt_and_tools(
|
||||
@@ -333,7 +340,7 @@ def create_pandas_dataframe_agent(
|
||||
input_variables=input_variables,
|
||||
include_df_in_prompt=include_df_in_prompt,
|
||||
number_of_head_rows=number_of_head_rows,
|
||||
use_pandas_eval_tool=use_pandas_eval_tool,
|
||||
use_sql_eval=use_sql_eval,
|
||||
)
|
||||
tools = base_tools + list(extra_tools)
|
||||
llm_chain = LLMChain(
|
||||
@@ -351,12 +358,13 @@ def create_pandas_dataframe_agent(
|
||||
elif agent_type == AgentType.OPENAI_FUNCTIONS:
|
||||
_prompt, base_tools = _get_functions_prompt_and_tools(
|
||||
df,
|
||||
llm,
|
||||
prefix=prefix,
|
||||
suffix=suffix,
|
||||
input_variables=input_variables,
|
||||
include_df_in_prompt=include_df_in_prompt,
|
||||
number_of_head_rows=number_of_head_rows,
|
||||
use_pandas_eval_tool=use_pandas_eval_tool,
|
||||
use_sql_eval=use_sql_eval,
|
||||
)
|
||||
tools = base_tools + list(extra_tools)
|
||||
agent = OpenAIFunctionsAgent(
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
from functools import partial
|
||||
from typing import Any, Optional
|
||||
from typing_extensions import TypeAlias
|
||||
from langchain.callbacks.manager import CallbackManagerForToolRun
|
||||
from typing import Any
|
||||
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForToolRun
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.schema.runnable.base import Runnable
|
||||
from langchain.schema.output_parser import StrOutputParser
|
||||
from langchain.schema.runnable.base import Runnable
|
||||
from langchain.tools.base import BaseTool
|
||||
|
||||
DF: TypeAlias = Any
|
||||
@@ -21,6 +22,9 @@ def evaluate_sql_on_dfs(sql: str, **dfs: DF) -> DF:
|
||||
"duckdb is required to evaluate SQL queries on pandas dataframes."
|
||||
)
|
||||
|
||||
if not sql:
|
||||
return ""
|
||||
|
||||
locals().update(dfs)
|
||||
conn = duckdb.connect()
|
||||
return conn.execute(sql).fetchall()
|
||||
@@ -30,13 +34,17 @@ def get_pandas_eval_chain(model: BaseLanguageModel, dfs: dict[str, DF]) -> Runna
|
||||
prompt = PromptTemplate.from_template(
|
||||
"""You are an expert data scientist, tasked with converting python code manipulating pandas dataframes into SQL queries.
|
||||
|
||||
You should write a SQL query that will return the same result as the python code below/There are SQL tables with the same name as any Pandas dataframe in the code.
|
||||
You should write a SQL query that will return the same result as the python code below/
|
||||
There are SQL tables with the same name as any Pandas dataframe in the code ({tables}).
|
||||
|
||||
You are given the following python code:
|
||||
|
||||
{input}
|
||||
|
||||
SQL query:"""
|
||||
If the python code is not valid pandas code, you should return an empty string.
|
||||
|
||||
SQL query:""", # noqa: E501
|
||||
partial_variables={"tables": str(list(dfs.keys()))},
|
||||
)
|
||||
|
||||
return prompt | model | StrOutputParser() | partial(evaluate_sql_on_dfs, **dfs)
|
||||
@@ -54,7 +62,7 @@ class PandasEvalTool(BaseTool):
|
||||
def _run(
|
||||
self,
|
||||
query: str,
|
||||
run_manager: Optional[CallbackManagerForToolRun] = None,
|
||||
run_manager: CallbackManagerForToolRun,
|
||||
) -> Any:
|
||||
chain = get_pandas_eval_chain(self.model, self.dfs)
|
||||
return chain.invoke({"input": query}, {"callbacks": run_manager.get_child()})
|
||||
|
||||
Reference in New Issue
Block a user