mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-13 13:36:15 +00:00
Add Multi-CSV/DF support in CSV and DataFrame Toolkits (#5009)
Add Multi-CSV/DF support in CSV and DataFrame Toolkits * CSV and DataFrame toolkits now accept list of CSVs/DFs * Add default prompts for many dataframes in `pandas_dataframe` toolkit Fixes #1958 Potentially fixes #4423 ## Testing * Add single and multi-dataframe integration tests for `pandas_dataframe` toolkit with permutations of `include_df_in_prompt` * Add single and multi-CSV integration tests for csv toolkit --------- Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
57
tests/integration_tests/agent/test_csv_agent.py
Normal file
57
tests/integration_tests/agent/test_csv_agent.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from _pytest.tmpdir import TempPathFactory
|
||||
from pandas import DataFrame
|
||||
|
||||
from langchain.agents import create_csv_agent
|
||||
from langchain.agents.agent import AgentExecutor
|
||||
from langchain.llms import OpenAI
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def csv(tmp_path_factory: TempPathFactory) -> DataFrame:
|
||||
random_data = np.random.rand(4, 4)
|
||||
df = DataFrame(random_data, columns=["name", "age", "food", "sport"])
|
||||
filename = str(tmp_path_factory.mktemp("data") / "test.csv")
|
||||
df.to_csv(filename)
|
||||
return filename
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def csv_list(tmp_path_factory: TempPathFactory) -> DataFrame:
|
||||
random_data = np.random.rand(4, 4)
|
||||
df1 = DataFrame(random_data, columns=["name", "age", "food", "sport"])
|
||||
filename1 = str(tmp_path_factory.mktemp("data") / "test1.csv")
|
||||
df1.to_csv(filename1)
|
||||
|
||||
random_data = np.random.rand(2, 2)
|
||||
df2 = DataFrame(random_data, columns=["name", "height"])
|
||||
filename2 = str(tmp_path_factory.mktemp("data") / "test2.csv")
|
||||
df2.to_csv(filename2)
|
||||
|
||||
return [filename1, filename2]
|
||||
|
||||
|
||||
def test_csv_agent_creation(csv: str) -> None:
|
||||
agent = create_csv_agent(OpenAI(temperature=0), csv)
|
||||
assert isinstance(agent, AgentExecutor)
|
||||
|
||||
|
||||
def test_single_csv(csv: str) -> None:
|
||||
agent = create_csv_agent(OpenAI(temperature=0), csv)
|
||||
assert isinstance(agent, AgentExecutor)
|
||||
response = agent.run("How many rows in the csv? Give me a number.")
|
||||
result = re.search(r".*(4).*", response)
|
||||
assert result is not None
|
||||
assert result.group(1) is not None
|
||||
|
||||
|
||||
def test_multi_csv(csv_list: list) -> None:
|
||||
agent = create_csv_agent(OpenAI(temperature=0), csv_list, verbose=True)
|
||||
assert isinstance(agent, AgentExecutor)
|
||||
response = agent.run("How many combined rows in the two csvs? Give me a number.")
|
||||
result = re.search(r".*(6).*", response)
|
||||
assert result is not None
|
||||
assert result.group(1) is not None
|
@@ -16,6 +16,17 @@ def df() -> DataFrame:
|
||||
return df
|
||||
|
||||
|
||||
# Figure out type hint here
|
||||
@pytest.fixture(scope="module")
|
||||
def df_list() -> list:
|
||||
random_data = np.random.rand(4, 4)
|
||||
df1 = DataFrame(random_data, columns=["name", "age", "food", "sport"])
|
||||
random_data = np.random.rand(2, 2)
|
||||
df2 = DataFrame(random_data, columns=["name", "height"])
|
||||
df_list = [df1, df2]
|
||||
return df_list
|
||||
|
||||
|
||||
def test_pandas_agent_creation(df: DataFrame) -> None:
|
||||
agent = create_pandas_dataframe_agent(OpenAI(temperature=0), df)
|
||||
assert isinstance(agent, AgentExecutor)
|
||||
@@ -28,3 +39,32 @@ def test_data_reading(df: DataFrame) -> None:
|
||||
result = re.search(rf".*({df.shape[0]}).*", response)
|
||||
assert result is not None
|
||||
assert result.group(1) is not None
|
||||
|
||||
|
||||
def test_data_reading_no_df_in_prompt(df: DataFrame) -> None:
|
||||
agent = create_pandas_dataframe_agent(
|
||||
OpenAI(temperature=0), df, include_df_in_prompt=False
|
||||
)
|
||||
assert isinstance(agent, AgentExecutor)
|
||||
response = agent.run("how many rows in df? Give me a number.")
|
||||
result = re.search(rf".*({df.shape[0]}).*", response)
|
||||
assert result is not None
|
||||
assert result.group(1) is not None
|
||||
|
||||
|
||||
def test_multi_df(df_list: list) -> None:
|
||||
agent = create_pandas_dataframe_agent(OpenAI(temperature=0), df_list, verbose=True)
|
||||
response = agent.run("how many total rows in the two dataframes? Give me a number.")
|
||||
result = re.search(r".*(6).*", response)
|
||||
assert result is not None
|
||||
assert result.group(1) is not None
|
||||
|
||||
|
||||
def test_multi_df_no_df_in_prompt(df_list: list) -> None:
|
||||
agent = create_pandas_dataframe_agent(
|
||||
OpenAI(temperature=0), df_list, include_df_in_prompt=False
|
||||
)
|
||||
response = agent.run("how many total rows in the two dataframes? Give me a number.")
|
||||
result = re.search(r".*(6).*", response)
|
||||
assert result is not None
|
||||
assert result.group(1) is not None
|
||||
|
Reference in New Issue
Block a user