community[patch]: avoid executing toolkit.get_context() when not necessary (#19762)

If `prompt` is passed into `create_sql_agent()`, then
`toolkit.get_context()` shouldn't be executed against the database
unless relevant prompt variables (`table_info` or `table_names`) are
present .
This commit is contained in:
Arturs Konfino 2024-03-29 18:42:21 +02:00 committed by GitHub
parent ec7a59c96c
commit 2319212d54
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -150,17 +150,18 @@ def create_sql_agent(
prompt = prompt.partial(top_k=str(top_k)) prompt = prompt.partial(top_k=str(top_k))
if "dialect" in prompt.input_variables: if "dialect" in prompt.input_variables:
prompt = prompt.partial(dialect=toolkit.dialect) prompt = prompt.partial(dialect=toolkit.dialect)
db_context = toolkit.get_context() if any(key in prompt.input_variables for key in ["table_info", "table_names"]):
if "table_info" in prompt.input_variables: db_context = toolkit.get_context()
prompt = prompt.partial(table_info=db_context["table_info"]) if "table_info" in prompt.input_variables:
tools = [ prompt = prompt.partial(table_info=db_context["table_info"])
tool for tool in tools if not isinstance(tool, InfoSQLDatabaseTool) tools = [
] tool for tool in tools if not isinstance(tool, InfoSQLDatabaseTool)
if "table_names" in prompt.input_variables: ]
prompt = prompt.partial(table_names=db_context["table_names"]) if "table_names" in prompt.input_variables:
tools = [ prompt = prompt.partial(table_names=db_context["table_names"])
tool for tool in tools if not isinstance(tool, ListSQLDatabaseTool) tools = [
] tool for tool in tools if not isinstance(tool, ListSQLDatabaseTool)
]
if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION: if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION:
if prompt is None: if prompt is None: