community[minor]: Allow passing allow_dangerous_deserialization when loading LLM chain (#18894)

### Issue
Recently, the new `allow_dangerous_deserialization` flag was introduced
for preventing unsafe model deserialization that relies on pickle
without the user's notice (#18696). Since then, some LLMs such as Databricks
require passing this flag as `True` to instantiate the model.

However, this breaks the existing functionality of loading such LLMs within
a chain using the `load_chain` method, because the underlying loader
function
[load_llm_from_config](f96dd57501/libs/langchain/langchain/chains/loading.py (L40))
 (and `load_llm`) ignores any keyword arguments passed in.

### Solution
This PR fixes the issue by propagating the
`allow_dangerous_deserialization` argument to the class loader if and only
if the LLM class declares that field.

---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
This commit is contained in:
Yuki Watanabe 2024-03-27 00:07:55 +09:00 committed by GitHub
parent d7c14cb6f9
commit cfecbda48b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 130 additions and 64 deletions

View File

@ -1,15 +1,17 @@
"""Base interface for loading large language model APIs.""" """Base interface for loading large language model APIs."""
import json import json
from pathlib import Path from pathlib import Path
from typing import Union from typing import Any, Union
import yaml import yaml
from langchain_core.language_models.llms import BaseLLM from langchain_core.language_models.llms import BaseLLM
from langchain_community.llms import get_type_to_cls_dict from langchain_community.llms import get_type_to_cls_dict
_ALLOW_DANGEROUS_DESERIALIZATION_ARG = "allow_dangerous_deserialization"
def load_llm_from_config(config: dict) -> BaseLLM:
def load_llm_from_config(config: dict, **kwargs: Any) -> BaseLLM:
"""Load LLM from Config Dict.""" """Load LLM from Config Dict."""
if "_type" not in config: if "_type" not in config:
raise ValueError("Must specify an LLM Type in config") raise ValueError("Must specify an LLM Type in config")
@ -21,11 +23,17 @@ def load_llm_from_config(config: dict) -> BaseLLM:
raise ValueError(f"Loading {config_type} LLM not supported") raise ValueError(f"Loading {config_type} LLM not supported")
llm_cls = type_to_cls_dict[config_type]() llm_cls = type_to_cls_dict[config_type]()
return llm_cls(**config)
load_kwargs = {}
if _ALLOW_DANGEROUS_DESERIALIZATION_ARG in llm_cls.__fields__:
load_kwargs[_ALLOW_DANGEROUS_DESERIALIZATION_ARG] = kwargs.get(
_ALLOW_DANGEROUS_DESERIALIZATION_ARG, False
)
return llm_cls(**config, **load_kwargs)
def load_llm(file: Union[str, Path]) -> BaseLLM: def load_llm(file: Union[str, Path], **kwargs: Any) -> BaseLLM:
"""Load LLM from file."""
# Convert file to Path object. # Convert file to Path object.
if isinstance(file, str): if isinstance(file, str):
file_path = Path(file) file_path = Path(file)
@ -41,4 +49,4 @@ def load_llm(file: Union[str, Path]) -> BaseLLM:
else: else:
raise ValueError("File type must be json or yaml") raise ValueError("File type must be json or yaml")
# Load the LLM from the config now. # Load the LLM from the config now.
return load_llm_from_config(config) return load_llm_from_config(config, **kwargs)

View File

@ -1,4 +1,5 @@
"""test Databricks LLM""" """test Databricks LLM"""
from pathlib import Path
from typing import Any, Dict from typing import Any, Dict
import pytest import pytest
@ -8,6 +9,8 @@ from langchain_community.llms.databricks import (
Databricks, Databricks,
_load_pickled_fn_from_hex_string, _load_pickled_fn_from_hex_string,
) )
from langchain_community.llms.loading import load_llm
from tests.integration_tests.llms.utils import assert_llm_equality
class MockDatabricksServingEndpointClient: class MockDatabricksServingEndpointClient:
@ -55,3 +58,26 @@ def test_serde_transform_input_fn(monkeypatch: MonkeyPatch) -> None:
request = {"prompt": "What is the meaning of life?"} request = {"prompt": "What is the meaning of life?"}
fn = _load_pickled_fn_from_hex_string(params["transform_input_fn"]) fn = _load_pickled_fn_from_hex_string(params["transform_input_fn"])
assert fn(**request) == transform_input(**request) assert fn(**request) == transform_input(**request)
def test_saving_loading_llm(monkeypatch: MonkeyPatch, tmp_path: Path) -> None:
monkeypatch.setattr(
"langchain_community.llms.databricks._DatabricksServingEndpointClient",
MockDatabricksServingEndpointClient,
)
monkeypatch.setenv("DATABRICKS_HOST", "my-default-host")
monkeypatch.setenv("DATABRICKS_TOKEN", "my-default-token")
llm = Databricks(
endpoint_name="chat", temperature=0.1, allow_dangerous_deserialization=True
)
llm.save(file_path=tmp_path / "databricks.yaml")
# Loading without allow_dangerous_deserialization=True should raise an error.
with pytest.raises(ValueError, match="This code relies on the pickle module."):
load_llm(tmp_path / "databricks.yaml")
loaded_llm = load_llm(
tmp_path / "databricks.yaml", allow_dangerous_deserialization=True
)
assert_llm_equality(llm, loaded_llm)

View File

@ -37,9 +37,9 @@ def _load_llm_chain(config: dict, **kwargs: Any) -> LLMChain:
"""Load LLM chain from config dict.""" """Load LLM chain from config dict."""
if "llm" in config: if "llm" in config:
llm_config = config.pop("llm") llm_config = config.pop("llm")
llm = load_llm_from_config(llm_config) llm = load_llm_from_config(llm_config, **kwargs)
elif "llm_path" in config: elif "llm_path" in config:
llm = load_llm(config.pop("llm_path")) llm = load_llm(config.pop("llm_path"), **kwargs)
else: else:
raise ValueError("One of `llm` or `llm_path` must be present.") raise ValueError("One of `llm` or `llm_path` must be present.")
@ -59,9 +59,9 @@ def _load_hyde_chain(config: dict, **kwargs: Any) -> HypotheticalDocumentEmbedde
"""Load hypothetical document embedder chain from config dict.""" """Load hypothetical document embedder chain from config dict."""
if "llm_chain" in config: if "llm_chain" in config:
llm_chain_config = config.pop("llm_chain") llm_chain_config = config.pop("llm_chain")
llm_chain = load_chain_from_config(llm_chain_config) llm_chain = load_chain_from_config(llm_chain_config, **kwargs)
elif "llm_chain_path" in config: elif "llm_chain_path" in config:
llm_chain = load_chain(config.pop("llm_chain_path")) llm_chain = load_chain(config.pop("llm_chain_path"), **kwargs)
else: else:
raise ValueError("One of `llm_chain` or `llm_chain_path` must be present.") raise ValueError("One of `llm_chain` or `llm_chain_path` must be present.")
if "embeddings" in kwargs: if "embeddings" in kwargs:
@ -78,9 +78,9 @@ def _load_hyde_chain(config: dict, **kwargs: Any) -> HypotheticalDocumentEmbedde
def _load_stuff_documents_chain(config: dict, **kwargs: Any) -> StuffDocumentsChain: def _load_stuff_documents_chain(config: dict, **kwargs: Any) -> StuffDocumentsChain:
if "llm_chain" in config: if "llm_chain" in config:
llm_chain_config = config.pop("llm_chain") llm_chain_config = config.pop("llm_chain")
llm_chain = load_chain_from_config(llm_chain_config) llm_chain = load_chain_from_config(llm_chain_config, **kwargs)
elif "llm_chain_path" in config: elif "llm_chain_path" in config:
llm_chain = load_chain(config.pop("llm_chain_path")) llm_chain = load_chain(config.pop("llm_chain_path"), **kwargs)
else: else:
raise ValueError("One of `llm_chain` or `llm_chain_path` must be present.") raise ValueError("One of `llm_chain` or `llm_chain_path` must be present.")
@ -107,9 +107,9 @@ def _load_map_reduce_documents_chain(
) -> MapReduceDocumentsChain: ) -> MapReduceDocumentsChain:
if "llm_chain" in config: if "llm_chain" in config:
llm_chain_config = config.pop("llm_chain") llm_chain_config = config.pop("llm_chain")
llm_chain = load_chain_from_config(llm_chain_config) llm_chain = load_chain_from_config(llm_chain_config, **kwargs)
elif "llm_chain_path" in config: elif "llm_chain_path" in config:
llm_chain = load_chain(config.pop("llm_chain_path")) llm_chain = load_chain(config.pop("llm_chain_path"), **kwargs)
else: else:
raise ValueError("One of `llm_chain` or `llm_chain_path` must be present.") raise ValueError("One of `llm_chain` or `llm_chain_path` must be present.")
@ -118,12 +118,14 @@ def _load_map_reduce_documents_chain(
if "reduce_documents_chain" in config: if "reduce_documents_chain" in config:
reduce_documents_chain = load_chain_from_config( reduce_documents_chain = load_chain_from_config(
config.pop("reduce_documents_chain") config.pop("reduce_documents_chain"), **kwargs
) )
elif "reduce_documents_chain_path" in config: elif "reduce_documents_chain_path" in config:
reduce_documents_chain = load_chain(config.pop("reduce_documents_chain_path")) reduce_documents_chain = load_chain(
config.pop("reduce_documents_chain_path"), **kwargs
)
else: else:
reduce_documents_chain = _load_reduce_documents_chain(config) reduce_documents_chain = _load_reduce_documents_chain(config, **kwargs)
return MapReduceDocumentsChain( return MapReduceDocumentsChain(
llm_chain=llm_chain, llm_chain=llm_chain,
@ -138,14 +140,22 @@ def _load_reduce_documents_chain(config: dict, **kwargs: Any) -> ReduceDocuments
if "combine_documents_chain" in config: if "combine_documents_chain" in config:
combine_document_chain_config = config.pop("combine_documents_chain") combine_document_chain_config = config.pop("combine_documents_chain")
combine_documents_chain = load_chain_from_config(combine_document_chain_config) combine_documents_chain = load_chain_from_config(
combine_document_chain_config, **kwargs
)
elif "combine_document_chain" in config: elif "combine_document_chain" in config:
combine_document_chain_config = config.pop("combine_document_chain") combine_document_chain_config = config.pop("combine_document_chain")
combine_documents_chain = load_chain_from_config(combine_document_chain_config) combine_documents_chain = load_chain_from_config(
combine_document_chain_config, **kwargs
)
elif "combine_documents_chain_path" in config: elif "combine_documents_chain_path" in config:
combine_documents_chain = load_chain(config.pop("combine_documents_chain_path")) combine_documents_chain = load_chain(
config.pop("combine_documents_chain_path"), **kwargs
)
elif "combine_document_chain_path" in config: elif "combine_document_chain_path" in config:
combine_documents_chain = load_chain(config.pop("combine_document_chain_path")) combine_documents_chain = load_chain(
config.pop("combine_document_chain_path"), **kwargs
)
else: else:
raise ValueError( raise ValueError(
"One of `combine_documents_chain` or " "One of `combine_documents_chain` or "
@ -158,11 +168,11 @@ def _load_reduce_documents_chain(config: dict, **kwargs: Any) -> ReduceDocuments
collapse_documents_chain = None collapse_documents_chain = None
else: else:
collapse_documents_chain = load_chain_from_config( collapse_documents_chain = load_chain_from_config(
collapse_document_chain_config collapse_document_chain_config, **kwargs
) )
elif "collapse_documents_chain_path" in config: elif "collapse_documents_chain_path" in config:
collapse_documents_chain = load_chain( collapse_documents_chain = load_chain(
config.pop("collapse_documents_chain_path") config.pop("collapse_documents_chain_path"), **kwargs
) )
elif "collapse_document_chain" in config: elif "collapse_document_chain" in config:
collapse_document_chain_config = config.pop("collapse_document_chain") collapse_document_chain_config = config.pop("collapse_document_chain")
@ -170,11 +180,11 @@ def _load_reduce_documents_chain(config: dict, **kwargs: Any) -> ReduceDocuments
collapse_documents_chain = None collapse_documents_chain = None
else: else:
collapse_documents_chain = load_chain_from_config( collapse_documents_chain = load_chain_from_config(
collapse_document_chain_config collapse_document_chain_config, **kwargs
) )
elif "collapse_document_chain_path" in config: elif "collapse_document_chain_path" in config:
collapse_documents_chain = load_chain( collapse_documents_chain = load_chain(
config.pop("collapse_document_chain_path") config.pop("collapse_document_chain_path"), **kwargs
) )
return ReduceDocumentsChain( return ReduceDocumentsChain(
@ -190,17 +200,17 @@ def _load_llm_bash_chain(config: dict, **kwargs: Any) -> Any:
llm_chain = None llm_chain = None
if "llm_chain" in config: if "llm_chain" in config:
llm_chain_config = config.pop("llm_chain") llm_chain_config = config.pop("llm_chain")
llm_chain = load_chain_from_config(llm_chain_config) llm_chain = load_chain_from_config(llm_chain_config, **kwargs)
elif "llm_chain_path" in config: elif "llm_chain_path" in config:
llm_chain = load_chain(config.pop("llm_chain_path")) llm_chain = load_chain(config.pop("llm_chain_path"), **kwargs)
# llm attribute is deprecated in favor of llm_chain, here to support old configs # llm attribute is deprecated in favor of llm_chain, here to support old configs
elif "llm" in config: elif "llm" in config:
llm_config = config.pop("llm") llm_config = config.pop("llm")
llm = load_llm_from_config(llm_config) llm = load_llm_from_config(llm_config, **kwargs)
# llm_path attribute is deprecated in favor of llm_chain_path, # llm_path attribute is deprecated in favor of llm_chain_path,
# its to support old configs # its to support old configs
elif "llm_path" in config: elif "llm_path" in config:
llm = load_llm(config.pop("llm_path")) llm = load_llm(config.pop("llm_path"), **kwargs)
else: else:
raise ValueError("One of `llm_chain` or `llm_chain_path` must be present.") raise ValueError("One of `llm_chain` or `llm_chain_path` must be present.")
if "prompt" in config: if "prompt" in config:
@ -217,9 +227,9 @@ def _load_llm_bash_chain(config: dict, **kwargs: Any) -> Any:
def _load_llm_checker_chain(config: dict, **kwargs: Any) -> LLMCheckerChain: def _load_llm_checker_chain(config: dict, **kwargs: Any) -> LLMCheckerChain:
if "llm" in config: if "llm" in config:
llm_config = config.pop("llm") llm_config = config.pop("llm")
llm = load_llm_from_config(llm_config) llm = load_llm_from_config(llm_config, **kwargs)
elif "llm_path" in config: elif "llm_path" in config:
llm = load_llm(config.pop("llm_path")) llm = load_llm(config.pop("llm_path"), **kwargs)
else: else:
raise ValueError("One of `llm` or `llm_path` must be present.") raise ValueError("One of `llm` or `llm_path` must be present.")
if "create_draft_answer_prompt" in config: if "create_draft_answer_prompt" in config:
@ -264,17 +274,17 @@ def _load_llm_math_chain(config: dict, **kwargs: Any) -> LLMMathChain:
llm_chain = None llm_chain = None
if "llm_chain" in config: if "llm_chain" in config:
llm_chain_config = config.pop("llm_chain") llm_chain_config = config.pop("llm_chain")
llm_chain = load_chain_from_config(llm_chain_config) llm_chain = load_chain_from_config(llm_chain_config, **kwargs)
elif "llm_chain_path" in config: elif "llm_chain_path" in config:
llm_chain = load_chain(config.pop("llm_chain_path")) llm_chain = load_chain(config.pop("llm_chain_path"), **kwargs)
# llm attribute is deprecated in favor of llm_chain, here to support old configs # llm attribute is deprecated in favor of llm_chain, here to support old configs
elif "llm" in config: elif "llm" in config:
llm_config = config.pop("llm") llm_config = config.pop("llm")
llm = load_llm_from_config(llm_config) llm = load_llm_from_config(llm_config, **kwargs)
# llm_path attribute is deprecated in favor of llm_chain_path, # llm_path attribute is deprecated in favor of llm_chain_path,
# its to support old configs # its to support old configs
elif "llm_path" in config: elif "llm_path" in config:
llm = load_llm(config.pop("llm_path")) llm = load_llm(config.pop("llm_path"), **kwargs)
else: else:
raise ValueError("One of `llm_chain` or `llm_chain_path` must be present.") raise ValueError("One of `llm_chain` or `llm_chain_path` must be present.")
if "prompt" in config: if "prompt" in config:
@ -293,9 +303,9 @@ def _load_map_rerank_documents_chain(
) -> MapRerankDocumentsChain: ) -> MapRerankDocumentsChain:
if "llm_chain" in config: if "llm_chain" in config:
llm_chain_config = config.pop("llm_chain") llm_chain_config = config.pop("llm_chain")
llm_chain = load_chain_from_config(llm_chain_config) llm_chain = load_chain_from_config(llm_chain_config, **kwargs)
elif "llm_chain_path" in config: elif "llm_chain_path" in config:
llm_chain = load_chain(config.pop("llm_chain_path")) llm_chain = load_chain(config.pop("llm_chain_path"), **kwargs)
else: else:
raise ValueError("One of `llm_chain` or `llm_chain_path` must be present.") raise ValueError("One of `llm_chain` or `llm_chain_path` must be present.")
return MapRerankDocumentsChain(llm_chain=llm_chain, **config) # type: ignore[arg-type] return MapRerankDocumentsChain(llm_chain=llm_chain, **config) # type: ignore[arg-type]
@ -306,9 +316,9 @@ def _load_pal_chain(config: dict, **kwargs: Any) -> Any:
if "llm_chain" in config: if "llm_chain" in config:
llm_chain_config = config.pop("llm_chain") llm_chain_config = config.pop("llm_chain")
llm_chain = load_chain_from_config(llm_chain_config) llm_chain = load_chain_from_config(llm_chain_config, **kwargs)
elif "llm_chain_path" in config: elif "llm_chain_path" in config:
llm_chain = load_chain(config.pop("llm_chain_path")) llm_chain = load_chain(config.pop("llm_chain_path"), **kwargs)
else: else:
raise ValueError("One of `llm_chain` or `llm_chain_path` must be present.") raise ValueError("One of `llm_chain` or `llm_chain_path` must be present.")
return PALChain(llm_chain=llm_chain, **config) # type: ignore[arg-type] return PALChain(llm_chain=llm_chain, **config) # type: ignore[arg-type]
@ -317,18 +327,18 @@ def _load_pal_chain(config: dict, **kwargs: Any) -> Any:
def _load_refine_documents_chain(config: dict, **kwargs: Any) -> RefineDocumentsChain: def _load_refine_documents_chain(config: dict, **kwargs: Any) -> RefineDocumentsChain:
if "initial_llm_chain" in config: if "initial_llm_chain" in config:
initial_llm_chain_config = config.pop("initial_llm_chain") initial_llm_chain_config = config.pop("initial_llm_chain")
initial_llm_chain = load_chain_from_config(initial_llm_chain_config) initial_llm_chain = load_chain_from_config(initial_llm_chain_config, **kwargs)
elif "initial_llm_chain_path" in config: elif "initial_llm_chain_path" in config:
initial_llm_chain = load_chain(config.pop("initial_llm_chain_path")) initial_llm_chain = load_chain(config.pop("initial_llm_chain_path"), **kwargs)
else: else:
raise ValueError( raise ValueError(
"One of `initial_llm_chain` or `initial_llm_chain_path` must be present." "One of `initial_llm_chain` or `initial_llm_chain_path` must be present."
) )
if "refine_llm_chain" in config: if "refine_llm_chain" in config:
refine_llm_chain_config = config.pop("refine_llm_chain") refine_llm_chain_config = config.pop("refine_llm_chain")
refine_llm_chain = load_chain_from_config(refine_llm_chain_config) refine_llm_chain = load_chain_from_config(refine_llm_chain_config, **kwargs)
elif "refine_llm_chain_path" in config: elif "refine_llm_chain_path" in config:
refine_llm_chain = load_chain(config.pop("refine_llm_chain_path")) refine_llm_chain = load_chain(config.pop("refine_llm_chain_path"), **kwargs)
else: else:
raise ValueError( raise ValueError(
"One of `refine_llm_chain` or `refine_llm_chain_path` must be present." "One of `refine_llm_chain` or `refine_llm_chain_path` must be present."
@ -349,9 +359,13 @@ def _load_refine_documents_chain(config: dict, **kwargs: Any) -> RefineDocuments
def _load_qa_with_sources_chain(config: dict, **kwargs: Any) -> QAWithSourcesChain: def _load_qa_with_sources_chain(config: dict, **kwargs: Any) -> QAWithSourcesChain:
if "combine_documents_chain" in config: if "combine_documents_chain" in config:
combine_documents_chain_config = config.pop("combine_documents_chain") combine_documents_chain_config = config.pop("combine_documents_chain")
combine_documents_chain = load_chain_from_config(combine_documents_chain_config) combine_documents_chain = load_chain_from_config(
combine_documents_chain_config, **kwargs
)
elif "combine_documents_chain_path" in config: elif "combine_documents_chain_path" in config:
combine_documents_chain = load_chain(config.pop("combine_documents_chain_path")) combine_documents_chain = load_chain(
config.pop("combine_documents_chain_path"), **kwargs
)
else: else:
raise ValueError( raise ValueError(
"One of `combine_documents_chain` or " "One of `combine_documents_chain` or "
@ -369,13 +383,13 @@ def _load_sql_database_chain(config: dict, **kwargs: Any) -> Any:
raise ValueError("`database` must be present.") raise ValueError("`database` must be present.")
if "llm_chain" in config: if "llm_chain" in config:
llm_chain_config = config.pop("llm_chain") llm_chain_config = config.pop("llm_chain")
chain = load_chain_from_config(llm_chain_config) chain = load_chain_from_config(llm_chain_config, **kwargs)
return SQLDatabaseChain(llm_chain=chain, database=database, **config) # type: ignore[arg-type] return SQLDatabaseChain(llm_chain=chain, database=database, **config)
if "llm" in config: if "llm" in config:
llm_config = config.pop("llm") llm_config = config.pop("llm")
llm = load_llm_from_config(llm_config) llm = load_llm_from_config(llm_config, **kwargs)
elif "llm_path" in config: elif "llm_path" in config:
llm = load_llm(config.pop("llm_path")) llm = load_llm(config.pop("llm_path"), **kwargs)
else: else:
raise ValueError("One of `llm` or `llm_path` must be present.") raise ValueError("One of `llm` or `llm_path` must be present.")
if "prompt" in config: if "prompt" in config:
@ -396,9 +410,13 @@ def _load_vector_db_qa_with_sources_chain(
raise ValueError("`vectorstore` must be present.") raise ValueError("`vectorstore` must be present.")
if "combine_documents_chain" in config: if "combine_documents_chain" in config:
combine_documents_chain_config = config.pop("combine_documents_chain") combine_documents_chain_config = config.pop("combine_documents_chain")
combine_documents_chain = load_chain_from_config(combine_documents_chain_config) combine_documents_chain = load_chain_from_config(
combine_documents_chain_config, **kwargs
)
elif "combine_documents_chain_path" in config: elif "combine_documents_chain_path" in config:
combine_documents_chain = load_chain(config.pop("combine_documents_chain_path")) combine_documents_chain = load_chain(
config.pop("combine_documents_chain_path"), **kwargs
)
else: else:
raise ValueError( raise ValueError(
"One of `combine_documents_chain` or " "One of `combine_documents_chain` or "
@ -418,9 +436,13 @@ def _load_retrieval_qa(config: dict, **kwargs: Any) -> RetrievalQA:
raise ValueError("`retriever` must be present.") raise ValueError("`retriever` must be present.")
if "combine_documents_chain" in config: if "combine_documents_chain" in config:
combine_documents_chain_config = config.pop("combine_documents_chain") combine_documents_chain_config = config.pop("combine_documents_chain")
combine_documents_chain = load_chain_from_config(combine_documents_chain_config) combine_documents_chain = load_chain_from_config(
combine_documents_chain_config, **kwargs
)
elif "combine_documents_chain_path" in config: elif "combine_documents_chain_path" in config:
combine_documents_chain = load_chain(config.pop("combine_documents_chain_path")) combine_documents_chain = load_chain(
config.pop("combine_documents_chain_path"), **kwargs
)
else: else:
raise ValueError( raise ValueError(
"One of `combine_documents_chain` or " "One of `combine_documents_chain` or "
@ -442,9 +464,13 @@ def _load_retrieval_qa_with_sources_chain(
raise ValueError("`retriever` must be present.") raise ValueError("`retriever` must be present.")
if "combine_documents_chain" in config: if "combine_documents_chain" in config:
combine_documents_chain_config = config.pop("combine_documents_chain") combine_documents_chain_config = config.pop("combine_documents_chain")
combine_documents_chain = load_chain_from_config(combine_documents_chain_config) combine_documents_chain = load_chain_from_config(
combine_documents_chain_config, **kwargs
)
elif "combine_documents_chain_path" in config: elif "combine_documents_chain_path" in config:
combine_documents_chain = load_chain(config.pop("combine_documents_chain_path")) combine_documents_chain = load_chain(
config.pop("combine_documents_chain_path"), **kwargs
)
else: else:
raise ValueError( raise ValueError(
"One of `combine_documents_chain` or " "One of `combine_documents_chain` or "
@ -464,9 +490,13 @@ def _load_vector_db_qa(config: dict, **kwargs: Any) -> VectorDBQA:
raise ValueError("`vectorstore` must be present.") raise ValueError("`vectorstore` must be present.")
if "combine_documents_chain" in config: if "combine_documents_chain" in config:
combine_documents_chain_config = config.pop("combine_documents_chain") combine_documents_chain_config = config.pop("combine_documents_chain")
combine_documents_chain = load_chain_from_config(combine_documents_chain_config) combine_documents_chain = load_chain_from_config(
combine_documents_chain_config, **kwargs
)
elif "combine_documents_chain_path" in config: elif "combine_documents_chain_path" in config:
combine_documents_chain = load_chain(config.pop("combine_documents_chain_path")) combine_documents_chain = load_chain(
config.pop("combine_documents_chain_path"), **kwargs
)
else: else:
raise ValueError( raise ValueError(
"One of `combine_documents_chain` or " "One of `combine_documents_chain` or "
@ -486,12 +516,14 @@ def _load_graph_cypher_chain(config: dict, **kwargs: Any) -> GraphCypherQAChain:
raise ValueError("`graph` must be present.") raise ValueError("`graph` must be present.")
if "cypher_generation_chain" in config: if "cypher_generation_chain" in config:
cypher_generation_chain_config = config.pop("cypher_generation_chain") cypher_generation_chain_config = config.pop("cypher_generation_chain")
cypher_generation_chain = load_chain_from_config(cypher_generation_chain_config) cypher_generation_chain = load_chain_from_config(
cypher_generation_chain_config, **kwargs
)
else: else:
raise ValueError("`cypher_generation_chain` must be present.") raise ValueError("`cypher_generation_chain` must be present.")
if "qa_chain" in config: if "qa_chain" in config:
qa_chain_config = config.pop("qa_chain") qa_chain_config = config.pop("qa_chain")
qa_chain = load_chain_from_config(qa_chain_config) qa_chain = load_chain_from_config(qa_chain_config, **kwargs)
else: else:
raise ValueError("`qa_chain` must be present.") raise ValueError("`qa_chain` must be present.")
@ -506,7 +538,7 @@ def _load_graph_cypher_chain(config: dict, **kwargs: Any) -> GraphCypherQAChain:
def _load_api_chain(config: dict, **kwargs: Any) -> APIChain: def _load_api_chain(config: dict, **kwargs: Any) -> APIChain:
if "api_request_chain" in config: if "api_request_chain" in config:
api_request_chain_config = config.pop("api_request_chain") api_request_chain_config = config.pop("api_request_chain")
api_request_chain = load_chain_from_config(api_request_chain_config) api_request_chain = load_chain_from_config(api_request_chain_config, **kwargs)
elif "api_request_chain_path" in config: elif "api_request_chain_path" in config:
api_request_chain = load_chain(config.pop("api_request_chain_path")) api_request_chain = load_chain(config.pop("api_request_chain_path"))
else: else:
@ -515,9 +547,9 @@ def _load_api_chain(config: dict, **kwargs: Any) -> APIChain:
) )
if "api_answer_chain" in config: if "api_answer_chain" in config:
api_answer_chain_config = config.pop("api_answer_chain") api_answer_chain_config = config.pop("api_answer_chain")
api_answer_chain = load_chain_from_config(api_answer_chain_config) api_answer_chain = load_chain_from_config(api_answer_chain_config, **kwargs)
elif "api_answer_chain_path" in config: elif "api_answer_chain_path" in config:
api_answer_chain = load_chain(config.pop("api_answer_chain_path")) api_answer_chain = load_chain(config.pop("api_answer_chain_path"), **kwargs)
else: else:
raise ValueError( raise ValueError(
"One of `api_answer_chain` or `api_answer_chain_path` must be present." "One of `api_answer_chain` or `api_answer_chain_path` must be present."
@ -537,9 +569,9 @@ def _load_api_chain(config: dict, **kwargs: Any) -> APIChain:
def _load_llm_requests_chain(config: dict, **kwargs: Any) -> LLMRequestsChain: def _load_llm_requests_chain(config: dict, **kwargs: Any) -> LLMRequestsChain:
if "llm_chain" in config: if "llm_chain" in config:
llm_chain_config = config.pop("llm_chain") llm_chain_config = config.pop("llm_chain")
llm_chain = load_chain_from_config(llm_chain_config) llm_chain = load_chain_from_config(llm_chain_config, **kwargs)
elif "llm_chain_path" in config: elif "llm_chain_path" in config:
llm_chain = load_chain(config.pop("llm_chain_path")) llm_chain = load_chain(config.pop("llm_chain_path"), **kwargs)
else: else:
raise ValueError("One of `llm_chain` or `llm_chain_path` must be present.") raise ValueError("One of `llm_chain` or `llm_chain_path` must be present.")
if "requests_wrapper" in kwargs: if "requests_wrapper" in kwargs: