mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-15 17:33:53 +00:00
SqlDatabaseToolkit should have custom llm for QueryChecke… (#2676)
…rTool (#2655) --------- Co-authored-by: Rushabh Agarwal <26388764+rushout09@users.noreply.github.com>
This commit is contained in:
parent
8d3b059332
commit
e23a596a18
@ -44,7 +44,7 @@ def create_sql_agent(
|
|||||||
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
|
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
|
||||||
return AgentExecutor.from_agent_and_tools(
|
return AgentExecutor.from_agent_and_tools(
|
||||||
agent=agent,
|
agent=agent,
|
||||||
tools=toolkit.get_tools(),
|
tools=tools,
|
||||||
verbose=verbose,
|
verbose=verbose,
|
||||||
max_iterations=max_iterations,
|
max_iterations=max_iterations,
|
||||||
early_stopping_method=early_stopping_method,
|
early_stopping_method=early_stopping_method,
|
||||||
|
@ -4,6 +4,8 @@ from typing import List
|
|||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from langchain.agents.agent_toolkits.base import BaseToolkit
|
from langchain.agents.agent_toolkits.base import BaseToolkit
|
||||||
|
from langchain.llms.base import BaseLLM
|
||||||
|
from langchain.llms.openai import OpenAI
|
||||||
from langchain.sql_database import SQLDatabase
|
from langchain.sql_database import SQLDatabase
|
||||||
from langchain.tools import BaseTool
|
from langchain.tools import BaseTool
|
||||||
from langchain.tools.sql_database.tool import (
|
from langchain.tools.sql_database.tool import (
|
||||||
@ -18,6 +20,7 @@ class SQLDatabaseToolkit(BaseToolkit):
|
|||||||
"""Toolkit for interacting with SQL databases."""
|
"""Toolkit for interacting with SQL databases."""
|
||||||
|
|
||||||
db: SQLDatabase = Field(exclude=True)
|
db: SQLDatabase = Field(exclude=True)
|
||||||
|
llm: BaseLLM = Field(default_factory=lambda: OpenAI(temperature=0))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dialect(self) -> str:
|
def dialect(self) -> str:
|
||||||
@ -35,5 +38,5 @@ class SQLDatabaseToolkit(BaseToolkit):
|
|||||||
QuerySQLDataBaseTool(db=self.db),
|
QuerySQLDataBaseTool(db=self.db),
|
||||||
InfoSQLDatabaseTool(db=self.db),
|
InfoSQLDatabaseTool(db=self.db),
|
||||||
ListSQLDatabaseTool(db=self.db),
|
ListSQLDatabaseTool(db=self.db),
|
||||||
QueryCheckerTool(db=self.db),
|
QueryCheckerTool(db=self.db, llm=self.llm),
|
||||||
]
|
]
|
||||||
|
@ -3,9 +3,9 @@
|
|||||||
from pydantic import BaseModel, Extra, Field, validator
|
from pydantic import BaseModel, Extra, Field, validator
|
||||||
|
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
from langchain.llms.openai import OpenAI
|
|
||||||
from langchain.prompts import PromptTemplate
|
from langchain.prompts import PromptTemplate
|
||||||
from langchain.sql_database import SQLDatabase
|
from langchain.sql_database import SQLDatabase
|
||||||
|
from langchain.llms.base import BaseLLM
|
||||||
from langchain.tools.base import BaseTool
|
from langchain.tools.base import BaseTool
|
||||||
from langchain.tools.sql_database.prompt import QUERY_CHECKER
|
from langchain.tools.sql_database.prompt import QUERY_CHECKER
|
||||||
|
|
||||||
@ -80,11 +80,12 @@ class QueryCheckerTool(BaseSQLDatabaseTool, BaseTool):
|
|||||||
Adapted from https://www.patterns.app/blog/2023/01/18/crunchbot-sql-analyst-gpt/"""
|
Adapted from https://www.patterns.app/blog/2023/01/18/crunchbot-sql-analyst-gpt/"""
|
||||||
|
|
||||||
template: str = QUERY_CHECKER
|
template: str = QUERY_CHECKER
|
||||||
|
llm: BaseLLM
|
||||||
llm_chain: LLMChain = Field(
|
llm_chain: LLMChain = Field(
|
||||||
default_factory=lambda: LLMChain(
|
default_factory=lambda: LLMChain(
|
||||||
llm=OpenAI(temperature=0),
|
llm=QueryCheckerTool.llm,
|
||||||
prompt=PromptTemplate(
|
prompt=PromptTemplate(
|
||||||
template=QUERY_CHECKER, input_variables=["query", "dialect"]
|
template=QueryCheckerTool.template, input_variables=["query", "dialect"]
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user