mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-21 14:43:07 +00:00
make it so everything goes through generate, which removes the need for two types of caches
79 lines
2.4 KiB
Python
79 lines
2.4 KiB
Python
"""Chain for interacting with SQL Database."""
|
|
from typing import Dict, List
|
|
|
|
from pydantic import BaseModel, Extra
|
|
|
|
from langchain.chains.base import Chain
|
|
from langchain.chains.llm import LLMChain
|
|
from langchain.chains.sql_database.prompt import PROMPT
|
|
from langchain.input import print_text
|
|
from langchain.llms.base import BaseLLM
|
|
from langchain.sql_database import SQLDatabase
|
|
|
|
|
|
class SQLDatabaseChain(Chain, BaseModel):
    """Chain for interacting with SQL Database.

    The chain asks the LLM to write a SQL query for a natural-language
    question, runs that query against ``database``, and then asks the LLM
    again to phrase a final answer from the query result.

    Example:
        .. code-block:: python

            from langchain import SQLDatabaseChain, OpenAI, SQLDatabase
            db = SQLDatabase(...)
            db_chain = SQLDatabaseChain(llm=OpenAI(), database=db)
    """

    llm: BaseLLM
    """LLM wrapper to use."""
    database: SQLDatabase
    """SQL Database to connect to."""
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Return the singular input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the singular output key.

        :meta private:
        """
        return [self.output_key]

    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
        """Answer the question in ``inputs`` by generating and running SQL.

        Args:
            inputs: Mapping that must contain ``self.input_key``; its value
                is the question to answer.

        Returns:
            A single-entry dict mapping ``self.output_key`` to the text
            returned by the second LLM call.

        Raises:
            KeyError: If ``self.input_key`` is missing from ``inputs``.
        """
        llm_chain = LLMChain(llm=self.llm, prompt=PROMPT)
        # The prompt is built up incrementally: question -> SQL -> result.
        input_text = f"{inputs[self.input_key]} \nSQLQuery:"
        if self.verbose:
            print_text(input_text)
        llm_inputs = {
            "input": input_text,
            "dialect": self.database.dialect,
            "table_info": self.database.table_info,
            # NOTE(review): presumably stops generation before the model
            # invents its own "SQLResult:" section -- confirm against
            # LLMChain's handling of the "stop" input.
            "stop": ["\nSQLResult:"],
        }
        # First pass: have the LLM produce the SQL command.
        sql_cmd = llm_chain.predict(**llm_inputs)
        if self.verbose:
            print_text(sql_cmd, color="green")
        # NOTE(review): the LLM-generated SQL is executed verbatim, with no
        # validation in this method. The database credentials should be
        # restricted to what the model is allowed to do.
        result = self.database.run(sql_cmd)
        if self.verbose:
            print_text("\nSQLResult: ")
            print_text(result, color="yellow")
            print_text("\nAnswer:")
        # Second pass: extend the same prompt with the query and its result,
        # then ask the LLM for the final answer (same inputs otherwise).
        input_text += f"{sql_cmd}\nSQLResult: {result}\nAnswer:"
        llm_inputs["input"] = input_text
        final_result = llm_chain.predict(**llm_inputs)
        if self.verbose:
            print_text(final_result, color="green")
        return {self.output_key: final_result}
|