mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-29 02:58:06 +00:00
parent
4074ea4c41
commit
1bc3244db9
@ -358,10 +358,16 @@ def _load_qa_with_sources_chain(config: dict, **kwargs: Any) -> QAWithSourcesCha
|
||||
|
||||
|
||||
def _load_sql_database_chain(config: dict, **kwargs: Any) -> Any:
|
||||
from langchain_experimental.sql import SQLDatabaseChain
|
||||
|
||||
if "database" in kwargs:
|
||||
database = kwargs.pop("database")
|
||||
else:
|
||||
raise ValueError("`database` must be present.")
|
||||
if "llm_chain" in config:
|
||||
llm_chain_config = config.pop("llm_chain")
|
||||
chain = load_chain_from_config(llm_chain_config)
|
||||
return SQLDatabaseChain(llm_chain=chain, database=database, **config)
|
||||
if "llm" in config:
|
||||
llm_config = config.pop("llm")
|
||||
llm = load_llm_from_config(llm_config)
|
||||
@ -374,7 +380,6 @@ def _load_sql_database_chain(config: dict, **kwargs: Any) -> Any:
|
||||
prompt = load_prompt_from_config(prompt_config)
|
||||
else:
|
||||
prompt = None
|
||||
from langchain_experimental.sql import SQLDatabaseChain
|
||||
|
||||
return SQLDatabaseChain.from_llm(llm, database, prompt=prompt, **config)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user