mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-15 09:23:57 +00:00
SqlDatabaseToolkit should have custom llm for QueryChecke… (#2676)
…rTool (#2655) --------- Co-authored-by: Rushabh Agarwal <26388764+rushout09@users.noreply.github.com>
This commit is contained in:
parent
8d3b059332
commit
e23a596a18
@ -44,7 +44,7 @@ def create_sql_agent(
|
||||
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
|
||||
return AgentExecutor.from_agent_and_tools(
|
||||
agent=agent,
|
||||
tools=toolkit.get_tools(),
|
||||
tools=tools,
|
||||
verbose=verbose,
|
||||
max_iterations=max_iterations,
|
||||
early_stopping_method=early_stopping_method,
|
||||
|
@ -4,6 +4,8 @@ from typing import List
|
||||
from pydantic import Field
|
||||
|
||||
from langchain.agents.agent_toolkits.base import BaseToolkit
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.llms.openai import OpenAI
|
||||
from langchain.sql_database import SQLDatabase
|
||||
from langchain.tools import BaseTool
|
||||
from langchain.tools.sql_database.tool import (
|
||||
@ -18,6 +20,7 @@ class SQLDatabaseToolkit(BaseToolkit):
|
||||
"""Toolkit for interacting with SQL databases."""
|
||||
|
||||
db: SQLDatabase = Field(exclude=True)
|
||||
llm: BaseLLM = Field(default_factory=lambda: OpenAI(temperature=0))
|
||||
|
||||
@property
|
||||
def dialect(self) -> str:
|
||||
@ -35,5 +38,5 @@ class SQLDatabaseToolkit(BaseToolkit):
|
||||
QuerySQLDataBaseTool(db=self.db),
|
||||
InfoSQLDatabaseTool(db=self.db),
|
||||
ListSQLDatabaseTool(db=self.db),
|
||||
QueryCheckerTool(db=self.db),
|
||||
QueryCheckerTool(db=self.db, llm=self.llm),
|
||||
]
|
||||
|
@ -3,9 +3,9 @@
|
||||
from pydantic import BaseModel, Extra, Field, validator
|
||||
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.llms.openai import OpenAI
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain.sql_database import SQLDatabase
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.tools.base import BaseTool
|
||||
from langchain.tools.sql_database.prompt import QUERY_CHECKER
|
||||
|
||||
@ -80,11 +80,12 @@ class QueryCheckerTool(BaseSQLDatabaseTool, BaseTool):
|
||||
Adapted from https://www.patterns.app/blog/2023/01/18/crunchbot-sql-analyst-gpt/"""
|
||||
|
||||
template: str = QUERY_CHECKER
|
||||
llm: BaseLLM
|
||||
llm_chain: LLMChain = Field(
|
||||
default_factory=lambda: LLMChain(
|
||||
llm=OpenAI(temperature=0),
|
||||
llm=QueryCheckerTool.llm,
|
||||
prompt=PromptTemplate(
|
||||
template=QUERY_CHECKER, input_variables=["query", "dialect"]
|
||||
template=QueryCheckerTool.template, input_variables=["query", "dialect"]
|
||||
),
|
||||
)
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user