mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-21 14:43:07 +00:00
Consolidate the logic for deciding when a chain can run with a single text input and a single text output. Open to feedback on naming, logic, and usefulness.
75 lines
2.3 KiB
Python
75 lines
2.3 KiB
Python
"""Chain for interacting with SQL Database."""
|
|
from typing import Dict, List
|
|
|
|
from pydantic import BaseModel, Extra
|
|
|
|
from langchain.chains.base import Chain
|
|
from langchain.chains.llm import LLMChain
|
|
from langchain.chains.sql_database.prompt import PROMPT
|
|
from langchain.input import ChainedInput
|
|
from langchain.llms.base import LLM
|
|
from langchain.sql_database import SQLDatabase
|
|
|
|
|
|
class SQLDatabaseChain(Chain, BaseModel):
    """Chain for interacting with SQL Database.

    The chain asks the LLM to write a SQL query for the user's question,
    runs that query against ``database``, and then asks the LLM again to
    phrase an answer from the query result.

    Example:
        .. code-block:: python

            from langchain import SQLDatabaseChain, OpenAI, SQLDatabase
            db = SQLDatabase(...)
            db_chain = SQLDatabaseChain(llm=OpenAI(), database=db)
    """

    llm: LLM
    """LLM wrapper to use."""
    database: SQLDatabase
    """SQL Database to connect to."""
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Return the singular input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the singular output key.

        :meta private:
        """
        return [self.output_key]

    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
        """Answer the question in ``inputs`` by generating and running SQL.

        Args:
            inputs: Mapping that must contain ``self.input_key``; its value is
                the question text placed at the start of the prompt input.

        Returns:
            A single-entry dict mapping ``self.output_key`` to the LLM's
            final answer text.
        """
        llm_chain = LLMChain(llm=self.llm, prompt=PROMPT)
        # The running transcript starts with the question followed by a cue
        # telling the LLM to write the SQL query next.
        chained_input = ChainedInput(
            inputs[self.input_key] + "\nSQLQuery:", verbose=self.verbose
        )
        llm_inputs = {
            "input": chained_input.input,
            "dialect": self.database.dialect,
            "table_info": self.database.table_info,
            # Presumably forwarded by LLMChain as a stop sequence, so a
            # completion ends before the model invents its own SQLResult.
            "stop": ["\nSQLResult:"],
        }
        # Step 1: have the LLM write the SQL query.
        sql_cmd = llm_chain.predict(**llm_inputs)
        chained_input.add(sql_cmd, color="green")
        # Step 2: execute the query.
        # NOTE(review): the LLM-generated SQL is run verbatim; consider giving
        # `database` a connection with restricted (e.g. read-only) permissions.
        result = self.database.run(sql_cmd)
        chained_input.add("\nSQLResult: ")
        chained_input.add(result, color="yellow")
        chained_input.add("\nAnswer:")
        # Step 3: re-prompt with the full transcript (question, query, result)
        # so the LLM produces the final answer.
        llm_inputs["input"] = chained_input.input
        final_result = llm_chain.predict(**llm_inputs)
        chained_input.add(final_result, color="green")
        return {self.output_key: final_result}
|