commit c57e506f9c (parent 068620a871)
Author: Bagatur
Date: 2024-03-19 18:49:38 -07:00

6 changed files with 28 additions and 64 deletions

View File

@@ -71,9 +71,15 @@ def load_agent_from_config(
     agent_cls = AGENT_TO_CLASS[config_type]
     if "llm_chain" in config:
-        config["llm_chain"] = load_chain_from_config(config.pop("llm_chain"))
+        config["llm_chain"] = load_chain_from_config(
+            config.pop("llm_chain"),
+            load_llm_from_config=kwargs.pop("load_llm_from_config"),
+        )
     elif "llm_chain_path" in config:
-        config["llm_chain"] = load_chain(config.pop("llm_chain_path"))
+        config["llm_chain"] = load_chain(
+            config.pop("llm_chain_path"),
+            load_llm_from_config=kwargs.pop("load_llm_from_config"),
+        )
     else:
         raise ValueError("One of `llm_chain` and `llm_chain_path` should be specified.")
     if "output_parser" in config:

View File

@@ -2,34 +2,9 @@ from langchain.agents import agent_toolkits
 from tests.unit_tests import assert_all_importable

 EXPECTED_ALL = [
-    "AINetworkToolkit",
-    "AmadeusToolkit",
-    "AzureCognitiveServicesToolkit",
-    "FileManagementToolkit",
-    "GmailToolkit",
-    "JiraToolkit",
-    "JsonToolkit",
-    "MultionToolkit",
-    "NasaToolkit",
-    "NLAToolkit",
-    "O365Toolkit",
-    "OpenAPIToolkit",
-    "PlayWrightBrowserToolkit",
-    "PowerBIToolkit",
-    "SlackToolkit",
-    "SteamToolkit",
-    "SQLDatabaseToolkit",
-    "SparkSQLToolkit",
     "VectorStoreInfo",
     "VectorStoreRouterToolkit",
     "VectorStoreToolkit",
-    "ZapierToolkit",
-    "create_json_agent",
-    "create_openapi_agent",
-    "create_pbi_agent",
-    "create_pbi_chat_agent",
-    "create_spark_sql_agent",
-    "create_sql_agent",
     "create_vectorstore_agent",
     "create_vectorstore_router_agent",
     "create_conversational_retrieval_agent",

View File

@@ -1,6 +1,7 @@
 from pathlib import Path
 from tempfile import TemporaryDirectory

+import pytest
 from langchain_community.llms.fake import FakeListLLM
 from langchain_core.tools import Tool
@@ -8,7 +9,10 @@ from langchain.agents.agent_types import AgentType
 from langchain.agents.initialize import initialize_agent, load_agent


+@pytest.mark.requires("langchain_community")
 def test_mrkl_serialization() -> None:
+    from langchain_community.llms.loading import load_llm_from_config
+
     agent = initialize_agent(
         [
             Tool(
@@ -24,4 +28,4 @@ def test_mrkl_serialization() -> None:
     with TemporaryDirectory() as tempdir:
         file = Path(tempdir) / "agent.json"
         agent.save_agent(file)
-        load_agent(file)
+        load_agent(file, load_llm_from_config=load_llm_from_config)

View File

@@ -2,38 +2,15 @@ from langchain import callbacks
from tests.unit_tests import assert_all_importable
EXPECTED_ALL = [
"AimCallbackHandler",
"ArgillaCallbackHandler",
"ArizeCallbackHandler",
"PromptLayerCallbackHandler",
"ArthurCallbackHandler",
"ClearMLCallbackHandler",
"CometCallbackHandler",
"ContextCallbackHandler",
"FileCallbackHandler",
"HumanApprovalCallbackHandler",
"InfinoCallbackHandler",
"MlflowCallbackHandler",
"LLMonitorCallbackHandler",
"OpenAICallbackHandler",
"StdOutCallbackHandler",
"AsyncIteratorCallbackHandler",
"StreamingStdOutCallbackHandler",
"FinalStreamingStdOutCallbackHandler",
"LLMThoughtLabeler",
"LangChainTracer",
"StreamlitCallbackHandler",
"WandbCallbackHandler",
"WhyLabsCallbackHandler",
"get_openai_callback",
"tracing_enabled",
"tracing_v2_enabled",
"collect_runs",
"wandb_tracing_enabled",
"FlyteCallbackHandler",
"SageMakerCallbackHandler",
"LabelStudioCallbackHandler",
"TrubricsCallbackHandler",
]

View File

@@ -1,13 +1,14 @@
"""Test LLM chain."""
from tempfile import TemporaryDirectory
from typing import Dict, List, Union
from unittest.mock import patch
import pytest
from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.prompts import PromptTemplate
from langchain.chains.llm import LLMChain
from langchain.chains.loading import load_chain
from tests.unit_tests.llms.fake_llm import FakeLLM
@@ -26,18 +27,16 @@ def fake_llm_chain() -> LLMChain:
     return LLMChain(prompt=prompt, llm=FakeLLM(), output_key="text1")


-@patch(
-    "langchain_community.llms.loading.get_type_to_cls_dict",
-    lambda: {"fake": lambda: FakeLLM},
-)
 def test_serialization(fake_llm_chain: LLMChain) -> None:
     """Test serialization."""
-    from langchain.chains.loading import load_chain
+    def load_llm_from_config(config: dict) -> BaseLanguageModel:
+        return FakeLLM()
+
     with TemporaryDirectory() as temp_dir:
         file = temp_dir + "/llm.json"
         fake_llm_chain.save(file)
-        loaded_chain = load_chain(file)
+        loaded_chain = load_chain(file, load_llm_from_config=load_llm_from_config)
         assert loaded_chain == fake_llm_chain
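
The same hook works when loading chains directly, and any callable with the signature used above will do. A sketch of a hand-rolled loader; the "_type" discriminator key and the registry contents are assumptions for illustration, in the spirit of the get_type_to_cls_dict patch this commit removes.

    from langchain_core.language_models import BaseLanguageModel

    from langchain.chains.loading import load_chain
    from tests.unit_tests.llms.fake_llm import FakeLLM

    # Hypothetical registry mapping a serialized LLM type string to a class.
    _LLM_REGISTRY = {"fake": FakeLLM}

    def load_llm_from_config(config: dict) -> BaseLanguageModel:
        return _LLM_REGISTRY[config.pop("_type")]()

    # "llm.json" stands in for a chain previously saved with LLMChain.save().
    chain = load_chain("llm.json", load_llm_from_config=load_llm_from_config)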

View File

@@ -1,7 +1,8 @@
-from langchain import chat_models
-from tests.unit_tests import assert_all_importable
+import pytest

-EXPECTED_ALL = [
+from langchain import chat_models
+
+EXPECTED_DEPRECATED_IMPORTS = [
     "ChatOpenAI",
     "BedrockChat",
     "AzureChatOpenAI",
@@ -35,6 +36,8 @@ EXPECTED_ALL = [
 ]


-def test_all_imports() -> None:
-    assert set(chat_models.__all__) == set(EXPECTED_ALL)
-    assert_all_importable(chat_models)
+def test_deprecated_imports() -> None:
+    for import_ in EXPECTED_DEPRECATED_IMPORTS:
+        with pytest.raises(ImportError) as e:
+            getattr(chat_models, import_)
assert "langchain_community" in e