mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-21 22:56:05 +00:00
fmt
This commit is contained in:
@@ -71,9 +71,15 @@ def load_agent_from_config(
|
||||
|
||||
agent_cls = AGENT_TO_CLASS[config_type]
|
||||
if "llm_chain" in config:
|
||||
config["llm_chain"] = load_chain_from_config(config.pop("llm_chain"))
|
||||
config["llm_chain"] = load_chain_from_config(
|
||||
config.pop("llm_chain"),
|
||||
load_llm_from_config=kwargs.pop("load_llm_from_config"),
|
||||
)
|
||||
elif "llm_chain_path" in config:
|
||||
config["llm_chain"] = load_chain(config.pop("llm_chain_path"))
|
||||
config["llm_chain"] = load_chain(
|
||||
config.pop("llm_chain_path"),
|
||||
load_llm_from_config=kwargs.pop("load_llm_from_config"),
|
||||
)
|
||||
else:
|
||||
raise ValueError("One of `llm_chain` and `llm_chain_path` should be specified.")
|
||||
if "output_parser" in config:
|
||||
|
||||
@@ -2,34 +2,9 @@ from langchain.agents import agent_toolkits
|
||||
from tests.unit_tests import assert_all_importable
|
||||
|
||||
EXPECTED_ALL = [
|
||||
"AINetworkToolkit",
|
||||
"AmadeusToolkit",
|
||||
"AzureCognitiveServicesToolkit",
|
||||
"FileManagementToolkit",
|
||||
"GmailToolkit",
|
||||
"JiraToolkit",
|
||||
"JsonToolkit",
|
||||
"MultionToolkit",
|
||||
"NasaToolkit",
|
||||
"NLAToolkit",
|
||||
"O365Toolkit",
|
||||
"OpenAPIToolkit",
|
||||
"PlayWrightBrowserToolkit",
|
||||
"PowerBIToolkit",
|
||||
"SlackToolkit",
|
||||
"SteamToolkit",
|
||||
"SQLDatabaseToolkit",
|
||||
"SparkSQLToolkit",
|
||||
"VectorStoreInfo",
|
||||
"VectorStoreRouterToolkit",
|
||||
"VectorStoreToolkit",
|
||||
"ZapierToolkit",
|
||||
"create_json_agent",
|
||||
"create_openapi_agent",
|
||||
"create_pbi_agent",
|
||||
"create_pbi_chat_agent",
|
||||
"create_spark_sql_agent",
|
||||
"create_sql_agent",
|
||||
"create_vectorstore_agent",
|
||||
"create_vectorstore_router_agent",
|
||||
"create_conversational_retrieval_agent",
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
import pytest
|
||||
from langchain_community.llms.fake import FakeListLLM
|
||||
from langchain_core.tools import Tool
|
||||
|
||||
@@ -8,7 +9,10 @@ from langchain.agents.agent_types import AgentType
|
||||
from langchain.agents.initialize import initialize_agent, load_agent
|
||||
|
||||
|
||||
@pytest.mark.requires("langchain_community")
|
||||
def test_mrkl_serialization() -> None:
|
||||
from langchain_community.llms.loading import load_llm_from_config
|
||||
|
||||
agent = initialize_agent(
|
||||
[
|
||||
Tool(
|
||||
@@ -24,4 +28,4 @@ def test_mrkl_serialization() -> None:
|
||||
with TemporaryDirectory() as tempdir:
|
||||
file = Path(tempdir) / "agent.json"
|
||||
agent.save_agent(file)
|
||||
load_agent(file)
|
||||
load_agent(file, load_llm_from_config=load_llm_from_config)
|
||||
|
||||
@@ -2,38 +2,15 @@ from langchain import callbacks
|
||||
from tests.unit_tests import assert_all_importable
|
||||
|
||||
EXPECTED_ALL = [
|
||||
"AimCallbackHandler",
|
||||
"ArgillaCallbackHandler",
|
||||
"ArizeCallbackHandler",
|
||||
"PromptLayerCallbackHandler",
|
||||
"ArthurCallbackHandler",
|
||||
"ClearMLCallbackHandler",
|
||||
"CometCallbackHandler",
|
||||
"ContextCallbackHandler",
|
||||
"FileCallbackHandler",
|
||||
"HumanApprovalCallbackHandler",
|
||||
"InfinoCallbackHandler",
|
||||
"MlflowCallbackHandler",
|
||||
"LLMonitorCallbackHandler",
|
||||
"OpenAICallbackHandler",
|
||||
"StdOutCallbackHandler",
|
||||
"AsyncIteratorCallbackHandler",
|
||||
"StreamingStdOutCallbackHandler",
|
||||
"FinalStreamingStdOutCallbackHandler",
|
||||
"LLMThoughtLabeler",
|
||||
"LangChainTracer",
|
||||
"StreamlitCallbackHandler",
|
||||
"WandbCallbackHandler",
|
||||
"WhyLabsCallbackHandler",
|
||||
"get_openai_callback",
|
||||
"tracing_enabled",
|
||||
"tracing_v2_enabled",
|
||||
"collect_runs",
|
||||
"wandb_tracing_enabled",
|
||||
"FlyteCallbackHandler",
|
||||
"SageMakerCallbackHandler",
|
||||
"LabelStudioCallbackHandler",
|
||||
"TrubricsCallbackHandler",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
"""Test LLM chain."""
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Dict, List, Union
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.output_parsers import BaseOutputParser
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.loading import load_chain
|
||||
from tests.unit_tests.llms.fake_llm import FakeLLM
|
||||
|
||||
|
||||
@@ -26,18 +27,16 @@ def fake_llm_chain() -> LLMChain:
|
||||
return LLMChain(prompt=prompt, llm=FakeLLM(), output_key="text1")
|
||||
|
||||
|
||||
@patch(
|
||||
"langchain_community.llms.loading.get_type_to_cls_dict",
|
||||
lambda: {"fake": lambda: FakeLLM},
|
||||
)
|
||||
def test_serialization(fake_llm_chain: LLMChain) -> None:
|
||||
"""Test serialization."""
|
||||
from langchain.chains.loading import load_chain
|
||||
|
||||
def load_llm_from_config(config: dict) -> BaseLanguageModel:
|
||||
return FakeLLM()
|
||||
|
||||
with TemporaryDirectory() as temp_dir:
|
||||
file = temp_dir + "/llm.json"
|
||||
fake_llm_chain.save(file)
|
||||
loaded_chain = load_chain(file)
|
||||
loaded_chain = load_chain(file, load_llm_from_config=load_llm_from_config)
|
||||
assert loaded_chain == fake_llm_chain
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from langchain import chat_models
|
||||
from tests.unit_tests import assert_all_importable
|
||||
import pytest
|
||||
|
||||
EXPECTED_ALL = [
|
||||
from langchain import chat_models
|
||||
|
||||
EXPECTED_DEPRECATED_IMPORTS = [
|
||||
"ChatOpenAI",
|
||||
"BedrockChat",
|
||||
"AzureChatOpenAI",
|
||||
@@ -35,6 +36,8 @@ EXPECTED_ALL = [
|
||||
]
|
||||
|
||||
|
||||
def test_all_imports() -> None:
|
||||
assert set(chat_models.__all__) == set(EXPECTED_ALL)
|
||||
assert_all_importable(chat_models)
|
||||
def test_deprecated_imports() -> None:
|
||||
for import_ in EXPECTED_DEPRECATED_IMPORTS:
|
||||
with pytest.raises(ImportError) as e:
|
||||
getattr(chat_models, import_)
|
||||
assert "langchain_community" in e
|
||||
|
||||
Reference in New Issue
Block a user