diff --git a/libs/community/langchain_community/llms/loading.py b/libs/community/langchain_community/llms/loading.py
index f410f7b7c6c..83a459265ed 100644
--- a/libs/community/langchain_community/llms/loading.py
+++ b/libs/community/langchain_community/llms/loading.py
@@ -1,15 +1,17 @@
 """Base interface for loading large language model APIs."""
 import json
 from pathlib import Path
-from typing import Union
+from typing import Any, Union
 
 import yaml
 from langchain_core.language_models.llms import BaseLLM
 
 from langchain_community.llms import get_type_to_cls_dict
 
+_ALLOW_DANGEROUS_DESERIALIZATION_ARG = "allow_dangerous_deserialization"
 
-def load_llm_from_config(config: dict) -> BaseLLM:
+
+def load_llm_from_config(config: dict, **kwargs: Any) -> BaseLLM:
     """Load LLM from Config Dict."""
     if "_type" not in config:
         raise ValueError("Must specify an LLM Type in config")
@@ -21,11 +23,18 @@ def load_llm_from_config(config: dict) -> BaseLLM:
         raise ValueError(f"Loading {config_type} LLM not supported")
 
     llm_cls = type_to_cls_dict[config_type]()
-    return llm_cls(**config)
+
+    load_kwargs = {}
+    if _ALLOW_DANGEROUS_DESERIALIZATION_ARG in llm_cls.__fields__:
+        load_kwargs[_ALLOW_DANGEROUS_DESERIALIZATION_ARG] = kwargs.get(
+            _ALLOW_DANGEROUS_DESERIALIZATION_ARG, False
+        )
+
+    return llm_cls(**config, **load_kwargs)
 
 
-def load_llm(file: Union[str, Path]) -> BaseLLM:
-    """Load LLM from file."""
+def load_llm(file: Union[str, Path], **kwargs: Any) -> BaseLLM:
+    """Load LLM from file."""
     # Convert file to Path object.
     if isinstance(file, str):
         file_path = Path(file)
@@ -41,4 +49,4 @@ def load_llm(file: Union[str, Path]) -> BaseLLM:
     else:
         raise ValueError("File type must be json or yaml")
     # Load the LLM from the config now.
-    return load_llm_from_config(config)
+    return load_llm_from_config(config, **kwargs)
diff --git a/libs/community/tests/unit_tests/llms/test_databricks.py b/libs/community/tests/unit_tests/llms/test_databricks.py
index fe93101d688..640f274a762 100644
--- a/libs/community/tests/unit_tests/llms/test_databricks.py
+++ b/libs/community/tests/unit_tests/llms/test_databricks.py
@@ -1,4 +1,5 @@
 """test Databricks LLM"""
+from pathlib import Path
 from typing import Any, Dict
 
 import pytest
@@ -8,6 +9,8 @@ from langchain_community.llms.databricks import (
     Databricks,
     _load_pickled_fn_from_hex_string,
 )
+from langchain_community.llms.loading import load_llm
+from tests.integration_tests.llms.utils import assert_llm_equality
 
 
 class MockDatabricksServingEndpointClient:
@@ -55,3 +58,26 @@ def test_serde_transform_input_fn(monkeypatch: MonkeyPatch) -> None:
     request = {"prompt": "What is the meaning of life?"}
     fn = _load_pickled_fn_from_hex_string(params["transform_input_fn"])
     assert fn(**request) == transform_input(**request)
+
+
+def test_saving_loading_llm(monkeypatch: MonkeyPatch, tmp_path: Path) -> None:
+    monkeypatch.setattr(
+        "langchain_community.llms.databricks._DatabricksServingEndpointClient",
+        MockDatabricksServingEndpointClient,
+    )
+    monkeypatch.setenv("DATABRICKS_HOST", "my-default-host")
+    monkeypatch.setenv("DATABRICKS_TOKEN", "my-default-token")
+
+    llm = Databricks(
+        endpoint_name="chat", temperature=0.1, allow_dangerous_deserialization=True
+    )
+    llm.save(file_path=tmp_path / "databricks.yaml")
+
+    # Loading without allow_dangerous_deserialization=True should raise an error.
+    with pytest.raises(ValueError, match="This code relies on the pickle module."):
+        load_llm(tmp_path / "databricks.yaml")
+
+    loaded_llm = load_llm(
+        tmp_path / "databricks.yaml", allow_dangerous_deserialization=True
+    )
+    assert_llm_equality(llm, loaded_llm)
diff --git a/libs/langchain/langchain/chains/loading.py b/libs/langchain/langchain/chains/loading.py
index a5c3e458188..32921145b1c 100644
--- a/libs/langchain/langchain/chains/loading.py
+++ b/libs/langchain/langchain/chains/loading.py
@@ -37,9 +37,9 @@ def _load_llm_chain(config: dict, **kwargs: Any) -> LLMChain:
     """Load LLM chain from config dict."""
     if "llm" in config:
         llm_config = config.pop("llm")
-        llm = load_llm_from_config(llm_config)
+        llm = load_llm_from_config(llm_config, **kwargs)
     elif "llm_path" in config:
-        llm = load_llm(config.pop("llm_path"))
+        llm = load_llm(config.pop("llm_path"), **kwargs)
     else:
         raise ValueError("One of `llm` or `llm_path` must be present.")
 
@@ -59,9 +59,9 @@ def _load_hyde_chain(config: dict, **kwargs: Any) -> HypotheticalDocumentEmbedde
     """Load hypothetical document embedder chain from config dict."""
     if "llm_chain" in config:
         llm_chain_config = config.pop("llm_chain")
-        llm_chain = load_chain_from_config(llm_chain_config)
+        llm_chain = load_chain_from_config(llm_chain_config, **kwargs)
     elif "llm_chain_path" in config:
-        llm_chain = load_chain(config.pop("llm_chain_path"))
+        llm_chain = load_chain(config.pop("llm_chain_path"), **kwargs)
     else:
         raise ValueError("One of `llm_chain` or `llm_chain_path` must be present.")
     if "embeddings" in kwargs:
@@ -78,9 +78,9 @@ def _load_hyde_chain(config: dict, **kwargs: Any) -> HypotheticalDocumentEmbedde
 def _load_stuff_documents_chain(config: dict, **kwargs: Any) -> StuffDocumentsChain:
     if "llm_chain" in config:
         llm_chain_config = config.pop("llm_chain")
-        llm_chain = load_chain_from_config(llm_chain_config)
+        llm_chain = load_chain_from_config(llm_chain_config, **kwargs)
     elif "llm_chain_path" in config:
-        llm_chain = load_chain(config.pop("llm_chain_path"))
+        llm_chain = load_chain(config.pop("llm_chain_path"), **kwargs)
     else:
         raise ValueError("One of `llm_chain` or `llm_chain_path` must be present.")
 
@@ -107,9 +107,9 @@ def _load_map_reduce_documents_chain(
 ) -> MapReduceDocumentsChain:
     if "llm_chain" in config:
         llm_chain_config = config.pop("llm_chain")
-        llm_chain = load_chain_from_config(llm_chain_config)
+        llm_chain = load_chain_from_config(llm_chain_config, **kwargs)
     elif "llm_chain_path" in config:
-        llm_chain = load_chain(config.pop("llm_chain_path"))
+        llm_chain = load_chain(config.pop("llm_chain_path"), **kwargs)
     else:
         raise ValueError("One of `llm_chain` or `llm_chain_path` must be present.")
 
@@ -118,12 +118,14 @@ def _load_map_reduce_documents_chain(
 
     if "reduce_documents_chain" in config:
         reduce_documents_chain = load_chain_from_config(
-            config.pop("reduce_documents_chain")
+            config.pop("reduce_documents_chain"), **kwargs
         )
     elif "reduce_documents_chain_path" in config:
-        reduce_documents_chain = load_chain(config.pop("reduce_documents_chain_path"))
+        reduce_documents_chain = load_chain(
+            config.pop("reduce_documents_chain_path"), **kwargs
+        )
     else:
-        reduce_documents_chain = _load_reduce_documents_chain(config)
+        reduce_documents_chain = _load_reduce_documents_chain(config, **kwargs)
 
     return MapReduceDocumentsChain(
         llm_chain=llm_chain,
@@ -138,14 +140,22 @@ def _load_reduce_documents_chain(config: dict, **kwargs: Any) -> ReduceDocuments
     if "combine_documents_chain" in config:
         combine_document_chain_config = config.pop("combine_documents_chain")
config.pop("combine_documents_chain") - combine_documents_chain = load_chain_from_config(combine_document_chain_config) + combine_documents_chain = load_chain_from_config( + combine_document_chain_config, **kwargs + ) elif "combine_document_chain" in config: combine_document_chain_config = config.pop("combine_document_chain") - combine_documents_chain = load_chain_from_config(combine_document_chain_config) + combine_documents_chain = load_chain_from_config( + combine_document_chain_config, **kwargs + ) elif "combine_documents_chain_path" in config: - combine_documents_chain = load_chain(config.pop("combine_documents_chain_path")) + combine_documents_chain = load_chain( + config.pop("combine_documents_chain_path"), **kwargs + ) elif "combine_document_chain_path" in config: - combine_documents_chain = load_chain(config.pop("combine_document_chain_path")) + combine_documents_chain = load_chain( + config.pop("combine_document_chain_path"), **kwargs + ) else: raise ValueError( "One of `combine_documents_chain` or " @@ -158,11 +168,11 @@ def _load_reduce_documents_chain(config: dict, **kwargs: Any) -> ReduceDocuments collapse_documents_chain = None else: collapse_documents_chain = load_chain_from_config( - collapse_document_chain_config + collapse_document_chain_config, **kwargs ) elif "collapse_documents_chain_path" in config: collapse_documents_chain = load_chain( - config.pop("collapse_documents_chain_path") + config.pop("collapse_documents_chain_path"), **kwargs ) elif "collapse_document_chain" in config: collapse_document_chain_config = config.pop("collapse_document_chain") @@ -170,11 +180,11 @@ def _load_reduce_documents_chain(config: dict, **kwargs: Any) -> ReduceDocuments collapse_documents_chain = None else: collapse_documents_chain = load_chain_from_config( - collapse_document_chain_config + collapse_document_chain_config, **kwargs ) elif "collapse_document_chain_path" in config: collapse_documents_chain = load_chain( - config.pop("collapse_document_chain_path") + config.pop("collapse_document_chain_path"), **kwargs ) return ReduceDocumentsChain( @@ -190,17 +200,17 @@ def _load_llm_bash_chain(config: dict, **kwargs: Any) -> Any: llm_chain = None if "llm_chain" in config: llm_chain_config = config.pop("llm_chain") - llm_chain = load_chain_from_config(llm_chain_config) + llm_chain = load_chain_from_config(llm_chain_config, **kwargs) elif "llm_chain_path" in config: - llm_chain = load_chain(config.pop("llm_chain_path")) + llm_chain = load_chain(config.pop("llm_chain_path"), **kwargs) # llm attribute is deprecated in favor of llm_chain, here to support old configs elif "llm" in config: llm_config = config.pop("llm") - llm = load_llm_from_config(llm_config) + llm = load_llm_from_config(llm_config, **kwargs) # llm_path attribute is deprecated in favor of llm_chain_path, # its to support old configs elif "llm_path" in config: - llm = load_llm(config.pop("llm_path")) + llm = load_llm(config.pop("llm_path"), **kwargs) else: raise ValueError("One of `llm_chain` or `llm_chain_path` must be present.") if "prompt" in config: @@ -217,9 +227,9 @@ def _load_llm_bash_chain(config: dict, **kwargs: Any) -> Any: def _load_llm_checker_chain(config: dict, **kwargs: Any) -> LLMCheckerChain: if "llm" in config: llm_config = config.pop("llm") - llm = load_llm_from_config(llm_config) + llm = load_llm_from_config(llm_config, **kwargs) elif "llm_path" in config: - llm = load_llm(config.pop("llm_path")) + llm = load_llm(config.pop("llm_path"), **kwargs) else: raise ValueError("One of `llm` or `llm_path` must be 
present.") if "create_draft_answer_prompt" in config: @@ -264,17 +274,17 @@ def _load_llm_math_chain(config: dict, **kwargs: Any) -> LLMMathChain: llm_chain = None if "llm_chain" in config: llm_chain_config = config.pop("llm_chain") - llm_chain = load_chain_from_config(llm_chain_config) + llm_chain = load_chain_from_config(llm_chain_config, **kwargs) elif "llm_chain_path" in config: - llm_chain = load_chain(config.pop("llm_chain_path")) + llm_chain = load_chain(config.pop("llm_chain_path"), **kwargs) # llm attribute is deprecated in favor of llm_chain, here to support old configs elif "llm" in config: llm_config = config.pop("llm") - llm = load_llm_from_config(llm_config) + llm = load_llm_from_config(llm_config, **kwargs) # llm_path attribute is deprecated in favor of llm_chain_path, # its to support old configs elif "llm_path" in config: - llm = load_llm(config.pop("llm_path")) + llm = load_llm(config.pop("llm_path"), **kwargs) else: raise ValueError("One of `llm_chain` or `llm_chain_path` must be present.") if "prompt" in config: @@ -293,9 +303,9 @@ def _load_map_rerank_documents_chain( ) -> MapRerankDocumentsChain: if "llm_chain" in config: llm_chain_config = config.pop("llm_chain") - llm_chain = load_chain_from_config(llm_chain_config) + llm_chain = load_chain_from_config(llm_chain_config, **kwargs) elif "llm_chain_path" in config: - llm_chain = load_chain(config.pop("llm_chain_path")) + llm_chain = load_chain(config.pop("llm_chain_path"), **kwargs) else: raise ValueError("One of `llm_chain` or `llm_chain_path` must be present.") return MapRerankDocumentsChain(llm_chain=llm_chain, **config) # type: ignore[arg-type] @@ -306,9 +316,9 @@ def _load_pal_chain(config: dict, **kwargs: Any) -> Any: if "llm_chain" in config: llm_chain_config = config.pop("llm_chain") - llm_chain = load_chain_from_config(llm_chain_config) + llm_chain = load_chain_from_config(llm_chain_config, **kwargs) elif "llm_chain_path" in config: - llm_chain = load_chain(config.pop("llm_chain_path")) + llm_chain = load_chain(config.pop("llm_chain_path"), **kwargs) else: raise ValueError("One of `llm_chain` or `llm_chain_path` must be present.") return PALChain(llm_chain=llm_chain, **config) # type: ignore[arg-type] @@ -317,18 +327,18 @@ def _load_pal_chain(config: dict, **kwargs: Any) -> Any: def _load_refine_documents_chain(config: dict, **kwargs: Any) -> RefineDocumentsChain: if "initial_llm_chain" in config: initial_llm_chain_config = config.pop("initial_llm_chain") - initial_llm_chain = load_chain_from_config(initial_llm_chain_config) + initial_llm_chain = load_chain_from_config(initial_llm_chain_config, **kwargs) elif "initial_llm_chain_path" in config: - initial_llm_chain = load_chain(config.pop("initial_llm_chain_path")) + initial_llm_chain = load_chain(config.pop("initial_llm_chain_path"), **kwargs) else: raise ValueError( "One of `initial_llm_chain` or `initial_llm_chain_path` must be present." ) if "refine_llm_chain" in config: refine_llm_chain_config = config.pop("refine_llm_chain") - refine_llm_chain = load_chain_from_config(refine_llm_chain_config) + refine_llm_chain = load_chain_from_config(refine_llm_chain_config, **kwargs) elif "refine_llm_chain_path" in config: - refine_llm_chain = load_chain(config.pop("refine_llm_chain_path")) + refine_llm_chain = load_chain(config.pop("refine_llm_chain_path"), **kwargs) else: raise ValueError( "One of `refine_llm_chain` or `refine_llm_chain_path` must be present." 
@@ -349,9 +359,13 @@ def _load_refine_documents_chain(config: dict, **kwargs: Any) -> RefineDocuments
 
 def _load_qa_with_sources_chain(config: dict, **kwargs: Any) -> QAWithSourcesChain:
     if "combine_documents_chain" in config:
         combine_documents_chain_config = config.pop("combine_documents_chain")
-        combine_documents_chain = load_chain_from_config(combine_documents_chain_config)
+        combine_documents_chain = load_chain_from_config(
+            combine_documents_chain_config, **kwargs
+        )
     elif "combine_documents_chain_path" in config:
-        combine_documents_chain = load_chain(config.pop("combine_documents_chain_path"))
+        combine_documents_chain = load_chain(
+            config.pop("combine_documents_chain_path"), **kwargs
+        )
     else:
         raise ValueError(
             "One of `combine_documents_chain` or "
@@ -369,13 +383,13 @@ def _load_sql_database_chain(config: dict, **kwargs: Any) -> Any:
         raise ValueError("`database` must be present.")
     if "llm_chain" in config:
         llm_chain_config = config.pop("llm_chain")
-        chain = load_chain_from_config(llm_chain_config)
-        return SQLDatabaseChain(llm_chain=chain, database=database, **config)  # type: ignore[arg-type]
+        chain = load_chain_from_config(llm_chain_config, **kwargs)
+        return SQLDatabaseChain(llm_chain=chain, database=database, **config)
     if "llm" in config:
         llm_config = config.pop("llm")
-        llm = load_llm_from_config(llm_config)
+        llm = load_llm_from_config(llm_config, **kwargs)
     elif "llm_path" in config:
-        llm = load_llm(config.pop("llm_path"))
+        llm = load_llm(config.pop("llm_path"), **kwargs)
     else:
         raise ValueError("One of `llm` or `llm_path` must be present.")
     if "prompt" in config:
@@ -396,9 +410,13 @@ def _load_vector_db_qa_with_sources_chain(
         raise ValueError("`vectorstore` must be present.")
     if "combine_documents_chain" in config:
         combine_documents_chain_config = config.pop("combine_documents_chain")
-        combine_documents_chain = load_chain_from_config(combine_documents_chain_config)
+        combine_documents_chain = load_chain_from_config(
+            combine_documents_chain_config, **kwargs
+        )
     elif "combine_documents_chain_path" in config:
-        combine_documents_chain = load_chain(config.pop("combine_documents_chain_path"))
+        combine_documents_chain = load_chain(
+            config.pop("combine_documents_chain_path"), **kwargs
+        )
     else:
         raise ValueError(
             "One of `combine_documents_chain` or "
@@ -418,9 +436,13 @@ def _load_retrieval_qa(config: dict, **kwargs: Any) -> RetrievalQA:
         raise ValueError("`retriever` must be present.")
     if "combine_documents_chain" in config:
         combine_documents_chain_config = config.pop("combine_documents_chain")
-        combine_documents_chain = load_chain_from_config(combine_documents_chain_config)
+        combine_documents_chain = load_chain_from_config(
+            combine_documents_chain_config, **kwargs
+        )
     elif "combine_documents_chain_path" in config:
-        combine_documents_chain = load_chain(config.pop("combine_documents_chain_path"))
+        combine_documents_chain = load_chain(
+            config.pop("combine_documents_chain_path"), **kwargs
+        )
     else:
         raise ValueError(
             "One of `combine_documents_chain` or "
@@ -442,9 +464,13 @@ def _load_retrieval_qa_with_sources_chain(
         raise ValueError("`retriever` must be present.")
     if "combine_documents_chain" in config:
         combine_documents_chain_config = config.pop("combine_documents_chain")
-        combine_documents_chain = load_chain_from_config(combine_documents_chain_config)
+        combine_documents_chain = load_chain_from_config(
+            combine_documents_chain_config, **kwargs
+        )
     elif "combine_documents_chain_path" in config:
-        combine_documents_chain = load_chain(config.pop("combine_documents_chain_path"))
+        combine_documents_chain = load_chain(
+            config.pop("combine_documents_chain_path"), **kwargs
+        )
     else:
         raise ValueError(
             "One of `combine_documents_chain` or "
@@ -464,9 +490,13 @@ def _load_vector_db_qa(config: dict, **kwargs: Any) -> VectorDBQA:
         raise ValueError("`vectorstore` must be present.")
     if "combine_documents_chain" in config:
         combine_documents_chain_config = config.pop("combine_documents_chain")
-        combine_documents_chain = load_chain_from_config(combine_documents_chain_config)
+        combine_documents_chain = load_chain_from_config(
+            combine_documents_chain_config, **kwargs
+        )
     elif "combine_documents_chain_path" in config:
-        combine_documents_chain = load_chain(config.pop("combine_documents_chain_path"))
+        combine_documents_chain = load_chain(
+            config.pop("combine_documents_chain_path"), **kwargs
+        )
     else:
         raise ValueError(
             "One of `combine_documents_chain` or "
@@ -486,12 +516,14 @@ def _load_graph_cypher_chain(config: dict, **kwargs: Any) -> GraphCypherQAChain:
         raise ValueError("`graph` must be present.")
     if "cypher_generation_chain" in config:
         cypher_generation_chain_config = config.pop("cypher_generation_chain")
-        cypher_generation_chain = load_chain_from_config(cypher_generation_chain_config)
+        cypher_generation_chain = load_chain_from_config(
+            cypher_generation_chain_config, **kwargs
+        )
     else:
         raise ValueError("`cypher_generation_chain` must be present.")
     if "qa_chain" in config:
         qa_chain_config = config.pop("qa_chain")
-        qa_chain = load_chain_from_config(qa_chain_config)
+        qa_chain = load_chain_from_config(qa_chain_config, **kwargs)
     else:
         raise ValueError("`qa_chain` must be present.")
 
@@ -506,7 +538,7 @@ def _load_graph_cypher_chain(config: dict, **kwargs: Any) -> GraphCypherQAChain:
 def _load_api_chain(config: dict, **kwargs: Any) -> APIChain:
     if "api_request_chain" in config:
         api_request_chain_config = config.pop("api_request_chain")
-        api_request_chain = load_chain_from_config(api_request_chain_config)
+        api_request_chain = load_chain_from_config(api_request_chain_config, **kwargs)
     elif "api_request_chain_path" in config:
         api_request_chain = load_chain(config.pop("api_request_chain_path"))
     else:
@@ -515,9 +547,9 @@ def _load_api_chain(config: dict, **kwargs: Any) -> APIChain:
     )
     if "api_answer_chain" in config:
         api_answer_chain_config = config.pop("api_answer_chain")
-        api_answer_chain = load_chain_from_config(api_answer_chain_config)
+        api_answer_chain = load_chain_from_config(api_answer_chain_config, **kwargs)
     elif "api_answer_chain_path" in config:
-        api_answer_chain = load_chain(config.pop("api_answer_chain_path"))
+        api_answer_chain = load_chain(config.pop("api_answer_chain_path"), **kwargs)
     else:
         raise ValueError(
             "One of `api_answer_chain` or `api_answer_chain_path` must be present."
@@ -537,9 +569,9 @@ def _load_api_chain(config: dict, **kwargs: Any) -> APIChain:
 def _load_llm_requests_chain(config: dict, **kwargs: Any) -> LLMRequestsChain:
     if "llm_chain" in config:
         llm_chain_config = config.pop("llm_chain")
-        llm_chain = load_chain_from_config(llm_chain_config)
+        llm_chain = load_chain_from_config(llm_chain_config, **kwargs)
     elif "llm_chain_path" in config:
-        llm_chain = load_chain(config.pop("llm_chain_path"))
+        llm_chain = load_chain(config.pop("llm_chain_path"), **kwargs)
     else:
         raise ValueError("One of `llm_chain` or `llm_chain_path` must be present.")
     if "requests_wrapper" in kwargs:
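
For context, a short sketch of the end-to-end behavior this patch produces. The file name is hypothetical; the flow mirrors `test_saving_loading_llm` above:

```python
from langchain_community.llms.loading import load_llm

# Configs whose LLM class declares `allow_dangerous_deserialization`
# (e.g. Databricks, which may unpickle transform functions) now fail
# closed when loaded without the explicit opt-in:
try:
    load_llm("databricks.yaml")
except ValueError as err:
    print(err)  # the message notes that this code relies on the pickle module

# With the opt-in, the flag travels load_llm -> load_llm_from_config (and,
# for chains, through every _load_*_chain helper) and is passed only to
# classes whose __fields__ declare allow_dangerous_deserialization.
llm = load_llm("databricks.yaml", allow_dangerous_deserialization=True)
```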