From c57e506f9cd6c8b5043c19ee937d4822770100a2 Mon Sep 17 00:00:00 2001 From: Bagatur Date: Tue, 19 Mar 2024 18:49:38 -0700 Subject: [PATCH] fmt --- libs/langchain/langchain/agents/loading.py | 10 ++++++-- .../agents/agent_toolkits/test_imports.py | 25 ------------------- .../unit_tests/agents/test_serialization.py | 6 ++++- .../unit_tests/callbacks/test_imports.py | 23 ----------------- .../tests/unit_tests/chains/test_llm.py | 13 +++++----- .../unit_tests/chat_models/test_imports.py | 15 ++++++----- 6 files changed, 28 insertions(+), 64 deletions(-) diff --git a/libs/langchain/langchain/agents/loading.py b/libs/langchain/langchain/agents/loading.py index e1d9747df2d..edecb4e6b9c 100644 --- a/libs/langchain/langchain/agents/loading.py +++ b/libs/langchain/langchain/agents/loading.py @@ -71,9 +71,15 @@ def load_agent_from_config( agent_cls = AGENT_TO_CLASS[config_type] if "llm_chain" in config: - config["llm_chain"] = load_chain_from_config(config.pop("llm_chain")) + config["llm_chain"] = load_chain_from_config( + config.pop("llm_chain"), + load_llm_from_config=kwargs.pop("load_llm_from_config"), + ) elif "llm_chain_path" in config: - config["llm_chain"] = load_chain(config.pop("llm_chain_path")) + config["llm_chain"] = load_chain( + config.pop("llm_chain_path"), + load_llm_from_config=kwargs.pop("load_llm_from_config"), + ) else: raise ValueError("One of `llm_chain` and `llm_chain_path` should be specified.") if "output_parser" in config: diff --git a/libs/langchain/tests/unit_tests/agents/agent_toolkits/test_imports.py b/libs/langchain/tests/unit_tests/agents/agent_toolkits/test_imports.py index c81ce791ea2..ade58982102 100644 --- a/libs/langchain/tests/unit_tests/agents/agent_toolkits/test_imports.py +++ b/libs/langchain/tests/unit_tests/agents/agent_toolkits/test_imports.py @@ -2,34 +2,9 @@ from langchain.agents import agent_toolkits from tests.unit_tests import assert_all_importable EXPECTED_ALL = [ - "AINetworkToolkit", - "AmadeusToolkit", - "AzureCognitiveServicesToolkit", - "FileManagementToolkit", - "GmailToolkit", - "JiraToolkit", - "JsonToolkit", - "MultionToolkit", - "NasaToolkit", - "NLAToolkit", - "O365Toolkit", - "OpenAPIToolkit", - "PlayWrightBrowserToolkit", - "PowerBIToolkit", - "SlackToolkit", - "SteamToolkit", - "SQLDatabaseToolkit", - "SparkSQLToolkit", "VectorStoreInfo", "VectorStoreRouterToolkit", "VectorStoreToolkit", - "ZapierToolkit", - "create_json_agent", - "create_openapi_agent", - "create_pbi_agent", - "create_pbi_chat_agent", - "create_spark_sql_agent", - "create_sql_agent", "create_vectorstore_agent", "create_vectorstore_router_agent", "create_conversational_retrieval_agent", diff --git a/libs/langchain/tests/unit_tests/agents/test_serialization.py b/libs/langchain/tests/unit_tests/agents/test_serialization.py index edba5475380..4f423a99181 100644 --- a/libs/langchain/tests/unit_tests/agents/test_serialization.py +++ b/libs/langchain/tests/unit_tests/agents/test_serialization.py @@ -1,6 +1,7 @@ from pathlib import Path from tempfile import TemporaryDirectory +import pytest from langchain_community.llms.fake import FakeListLLM from langchain_core.tools import Tool @@ -8,7 +9,10 @@ from langchain.agents.agent_types import AgentType from langchain.agents.initialize import initialize_agent, load_agent +@pytest.mark.requires("langchain_community") def test_mrkl_serialization() -> None: + from langchain_community.llms.loading import load_llm_from_config + agent = initialize_agent( [ Tool( @@ -24,4 +28,4 @@ def test_mrkl_serialization() -> None: with TemporaryDirectory() as tempdir: file = Path(tempdir) / "agent.json" agent.save_agent(file) - load_agent(file) + load_agent(file, load_llm_from_config=load_llm_from_config) diff --git a/libs/langchain/tests/unit_tests/callbacks/test_imports.py b/libs/langchain/tests/unit_tests/callbacks/test_imports.py index 3e01ae49534..ebd79934f30 100644 --- a/libs/langchain/tests/unit_tests/callbacks/test_imports.py +++ b/libs/langchain/tests/unit_tests/callbacks/test_imports.py @@ -2,38 +2,15 @@ from langchain import callbacks from tests.unit_tests import assert_all_importable EXPECTED_ALL = [ - "AimCallbackHandler", - "ArgillaCallbackHandler", - "ArizeCallbackHandler", - "PromptLayerCallbackHandler", - "ArthurCallbackHandler", - "ClearMLCallbackHandler", - "CometCallbackHandler", - "ContextCallbackHandler", "FileCallbackHandler", - "HumanApprovalCallbackHandler", - "InfinoCallbackHandler", - "MlflowCallbackHandler", - "LLMonitorCallbackHandler", - "OpenAICallbackHandler", "StdOutCallbackHandler", "AsyncIteratorCallbackHandler", "StreamingStdOutCallbackHandler", "FinalStreamingStdOutCallbackHandler", - "LLMThoughtLabeler", "LangChainTracer", - "StreamlitCallbackHandler", - "WandbCallbackHandler", - "WhyLabsCallbackHandler", - "get_openai_callback", "tracing_enabled", "tracing_v2_enabled", "collect_runs", - "wandb_tracing_enabled", - "FlyteCallbackHandler", - "SageMakerCallbackHandler", - "LabelStudioCallbackHandler", - "TrubricsCallbackHandler", ] diff --git a/libs/langchain/tests/unit_tests/chains/test_llm.py b/libs/langchain/tests/unit_tests/chains/test_llm.py index 0179cd135f2..0e50fbb671d 100644 --- a/libs/langchain/tests/unit_tests/chains/test_llm.py +++ b/libs/langchain/tests/unit_tests/chains/test_llm.py @@ -1,13 +1,14 @@ """Test LLM chain.""" from tempfile import TemporaryDirectory from typing import Dict, List, Union -from unittest.mock import patch import pytest +from langchain_core.language_models import BaseLanguageModel from langchain_core.output_parsers import BaseOutputParser from langchain_core.prompts import PromptTemplate from langchain.chains.llm import LLMChain +from langchain.chains.loading import load_chain from tests.unit_tests.llms.fake_llm import FakeLLM @@ -26,18 +27,16 @@ def fake_llm_chain() -> LLMChain: return LLMChain(prompt=prompt, llm=FakeLLM(), output_key="text1") -@patch( - "langchain_community.llms.loading.get_type_to_cls_dict", - lambda: {"fake": lambda: FakeLLM}, -) def test_serialization(fake_llm_chain: LLMChain) -> None: """Test serialization.""" - from langchain.chains.loading import load_chain + + def load_llm_from_config(config: dict) -> BaseLanguageModel: + return FakeLLM() with TemporaryDirectory() as temp_dir: file = temp_dir + "/llm.json" fake_llm_chain.save(file) - loaded_chain = load_chain(file) + loaded_chain = load_chain(file, load_llm_from_config=load_llm_from_config) assert loaded_chain == fake_llm_chain diff --git a/libs/langchain/tests/unit_tests/chat_models/test_imports.py b/libs/langchain/tests/unit_tests/chat_models/test_imports.py index e27df46d555..c4db92357bf 100644 --- a/libs/langchain/tests/unit_tests/chat_models/test_imports.py +++ b/libs/langchain/tests/unit_tests/chat_models/test_imports.py @@ -1,7 +1,8 @@ -from langchain import chat_models -from tests.unit_tests import assert_all_importable +import pytest -EXPECTED_ALL = [ +from langchain import chat_models + +EXPECTED_DEPRECATED_IMPORTS = [ "ChatOpenAI", "BedrockChat", "AzureChatOpenAI", @@ -35,6 +36,8 @@ EXPECTED_ALL = [ ] -def test_all_imports() -> None: - assert set(chat_models.__all__) == set(EXPECTED_ALL) - assert_all_importable(chat_models) +def test_deprecated_imports() -> None: + for import_ in EXPECTED_DEPRECATED_IMPORTS: + with pytest.raises(ImportError) as e: + getattr(chat_models, import_) + assert "langchain_community" in e