mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-11 05:45:01 +00:00
community[patch]: avoid executing toolkit.get_context()
when not necessary (#19762)
If `prompt` is passed into `create_sql_agent()`, then `toolkit.get_context()` shouldn't be executed against the database unless relevant prompt variables (`table_info` or `table_names`) are present .
This commit is contained in:
parent
ec7a59c96c
commit
2319212d54
@ -150,17 +150,18 @@ def create_sql_agent(
|
||||
prompt = prompt.partial(top_k=str(top_k))
|
||||
if "dialect" in prompt.input_variables:
|
||||
prompt = prompt.partial(dialect=toolkit.dialect)
|
||||
db_context = toolkit.get_context()
|
||||
if "table_info" in prompt.input_variables:
|
||||
prompt = prompt.partial(table_info=db_context["table_info"])
|
||||
tools = [
|
||||
tool for tool in tools if not isinstance(tool, InfoSQLDatabaseTool)
|
||||
]
|
||||
if "table_names" in prompt.input_variables:
|
||||
prompt = prompt.partial(table_names=db_context["table_names"])
|
||||
tools = [
|
||||
tool for tool in tools if not isinstance(tool, ListSQLDatabaseTool)
|
||||
]
|
||||
if any(key in prompt.input_variables for key in ["table_info", "table_names"]):
|
||||
db_context = toolkit.get_context()
|
||||
if "table_info" in prompt.input_variables:
|
||||
prompt = prompt.partial(table_info=db_context["table_info"])
|
||||
tools = [
|
||||
tool for tool in tools if not isinstance(tool, InfoSQLDatabaseTool)
|
||||
]
|
||||
if "table_names" in prompt.input_variables:
|
||||
prompt = prompt.partial(table_names=db_context["table_names"])
|
||||
tools = [
|
||||
tool for tool in tools if not isinstance(tool, ListSQLDatabaseTool)
|
||||
]
|
||||
|
||||
if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION:
|
||||
if prompt is None:
|
||||
|
Loading…
Reference in New Issue
Block a user