langchain[patch],community[patch]: Move unit tests that depend on community to community (#21685)

Eugene Yurtsev
2024-05-16 17:24:27 -04:00
committed by GitHub
parent 97a4ae50d2
commit 8607735b80
22 changed files with 1248 additions and 262 deletions
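
In short, the tests moved or rewritten below trade between three gating idioms for the optional langchain-community dependency: a custom pytest marker (whose pyproject.toml registration is also removed here), module-level pytest.importorskip, and importlib.util.find_spec checks. For orientation, a minimal sketch of all three; the test names and bodies are illustrative only:

import importlib.util

import pytest

# 1. Custom marker; only valid while "community" is registered under
#    [tool.pytest.ini_options] markers in pyproject.toml.
@pytest.mark.community
def test_marker_style() -> None: ...

# 2. Skip everything below this line unless langchain_community imports cleanly.
pytest.importorskip("langchain_community")

# 3. Explicit availability flag, usable for per-test skips and branching.
COMMUNITY_NOT_INSTALLED = importlib.util.find_spec("langchain_community") is None

@pytest.mark.skipif(COMMUNITY_NOT_INSTALLED, reason="langchain_community not installed")
def test_skipif_style() -> None: ...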

View File

@@ -23,7 +23,7 @@ from langchain._api import create_importer
_module_lookup = {
"APIChain": "langchain.chains.api.base",
"OpenAPIEndpointChain": "langchain.chains.api.openapi.chain",
"OpenAPIEndpointChain": "langchain_community.chains.openapi.chain",
"AnalyzeDocumentChain": "langchain.chains.combine_documents.base",
"MapReduceDocumentsChain": "langchain.chains.combine_documents.map_reduce",
"MapRerankDocumentsChain": "langchain.chains.combine_documents.map_rerank",
@@ -36,23 +36,23 @@ _module_lookup = {
"ConversationalRetrievalChain": "langchain.chains.conversational_retrieval.base",
"generate_example": "langchain.chains.example_generator",
"FlareChain": "langchain.chains.flare.base",
"ArangoGraphQAChain": "langchain.chains.graph_qa.arangodb",
"GraphQAChain": "langchain.chains.graph_qa.base",
"GraphCypherQAChain": "langchain.chains.graph_qa.cypher",
"FalkorDBQAChain": "langchain.chains.graph_qa.falkordb",
"HugeGraphQAChain": "langchain.chains.graph_qa.hugegraph",
"KuzuQAChain": "langchain.chains.graph_qa.kuzu",
"NebulaGraphQAChain": "langchain.chains.graph_qa.nebulagraph",
"NeptuneOpenCypherQAChain": "langchain.chains.graph_qa.neptune_cypher",
"NeptuneSparqlQAChain": "langchain.chains.graph_qa.neptune_sparql",
"OntotextGraphDBQAChain": "langchain.chains.graph_qa.ontotext_graphdb",
"GraphSparqlQAChain": "langchain.chains.graph_qa.sparql",
"ArangoGraphQAChain": "langchain_community.chains.graph_qa.arangodb",
"GraphQAChain": "langchain_community.chains.graph_qa.base",
"GraphCypherQAChain": "langchain_community.chains.graph_qa.cypher",
"FalkorDBQAChain": "langchain_community.chains.graph_qa.falkordb",
"HugeGraphQAChain": "langchain_community.chains.graph_qa.hugegraph",
"KuzuQAChain": "langchain_community.chains.graph_qa.kuzu",
"NebulaGraphQAChain": "langchain_community.chains.graph_qa.nebulagraph",
"NeptuneOpenCypherQAChain": "langchain_community.chains.graph_qa.neptune_cypher",
"NeptuneSparqlQAChain": "langchain_community.chains.graph_qa.neptune_sparql",
"OntotextGraphDBQAChain": "langchain_community.chains.graph_qa.ontotext_graphdb",
"GraphSparqlQAChain": "langchain_community.chains.graph_qa.sparql",
"create_history_aware_retriever": "langchain.chains.history_aware_retriever",
"HypotheticalDocumentEmbedder": "langchain.chains.hyde.base",
"LLMChain": "langchain.chains.llm",
"LLMCheckerChain": "langchain.chains.llm_checker.base",
"LLMMathChain": "langchain.chains.llm_math.base",
"LLMRequestsChain": "langchain.chains.llm_requests",
"LLMRequestsChain": "langchain_community.chains.llm_requests",
"LLMSummarizationCheckerChain": "langchain.chains.llm_summarization_checker.base",
"load_chain": "langchain.chains.loading",
"MapReduceChain": "langchain.chains.mapreduce",

View File

@@ -343,7 +343,6 @@ addopts = "--strict-markers --strict-config --durations=5 --snapshot-warn-unused
markers = [
"requires: mark tests as requiring a specific library",
"scheduled: mark tests to run in scheduled testing",
"community: mark tests that require langchain-community to be installed",
"compile: mark placeholder test used to compile integration tests without running them"
]
asyncio_mode = "auto"

View File

@@ -1,30 +0,0 @@
from pathlib import Path
from tempfile import TemporaryDirectory
import pytest
from langchain_core.language_models import FakeListLLM
from langchain_core.tools import Tool
from langchain.agents.agent_types import AgentType
from langchain.agents.initialize import initialize_agent, load_agent
pytest.importorskip("langchain_community")
def test_mrkl_serialization() -> None:
agent = initialize_agent(
[
Tool(
name="Test tool",
func=lambda x: x,
description="Test description",
)
],
FakeListLLM(responses=[]),
agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
verbose=True,
)
with TemporaryDirectory() as tempdir:
file = Path(tempdir) / "agent.json"
agent.save_agent(file)
load_agent(file)

View File

@@ -1,61 +0,0 @@
"""Test the loading function for evaluators."""
from typing import List
import pytest
from langchain_core.embeddings import FakeEmbeddings
from langchain.evaluation.loading import EvaluatorType, load_evaluators
from langchain.evaluation.schema import PairwiseStringEvaluator, StringEvaluator
from tests.unit_tests.llms.fake_chat_model import FakeChatModel
from tests.unit_tests.llms.fake_llm import FakeLLM
@pytest.mark.requires("rapidfuzz")
@pytest.mark.parametrize("evaluator_type", EvaluatorType)
def test_load_evaluators(evaluator_type: EvaluatorType) -> None:
"""Test loading evaluators."""
fake_llm = FakeChatModel()
embeddings = FakeEmbeddings(size=32)
load_evaluators([evaluator_type], llm=fake_llm, embeddings=embeddings)
# Test as string
load_evaluators(
[evaluator_type.value], # type: ignore
llm=fake_llm,
embeddings=embeddings,
)
@pytest.mark.community
@pytest.mark.parametrize(
"evaluator_types",
[
[EvaluatorType.LABELED_CRITERIA],
[EvaluatorType.LABELED_PAIRWISE_STRING],
[EvaluatorType.LABELED_SCORE_STRING],
[EvaluatorType.QA],
[EvaluatorType.CONTEXT_QA],
[EvaluatorType.COT_QA],
[EvaluatorType.COT_QA, EvaluatorType.LABELED_CRITERIA],
[
EvaluatorType.COT_QA,
EvaluatorType.LABELED_CRITERIA,
EvaluatorType.LABELED_PAIRWISE_STRING,
],
[EvaluatorType.JSON_EQUALITY],
[EvaluatorType.EXACT_MATCH, EvaluatorType.REGEX_MATCH],
],
)
def test_eval_chain_requires_references(evaluator_types: List[EvaluatorType]) -> None:
"""Test loading evaluators."""
fake_llm = FakeLLM(
queries={"text": "The meaning of life\nCORRECT"}, sequential_responses=True
)
evaluators = load_evaluators(
evaluator_types,
llm=fake_llm,
)
for evaluator in evaluators:
if not isinstance(evaluator, (StringEvaluator, PairwiseStringEvaluator)):
raise ValueError("Evaluator is not a [pairwise]string evaluator")
assert evaluator.requires_reference

View File

@@ -1,14 +1,4 @@
"""Test base LLM functionality."""
- import importlib
- import pytest
- from sqlalchemy import Column, Integer, Sequence, String, create_engine
- try:
-     from sqlalchemy.orm import declarative_base
- except ImportError:
-     from sqlalchemy.ext.declarative import declarative_base
from langchain_core.caches import InMemoryCache
from langchain_core.outputs import Generation, LLMResult
@@ -50,48 +40,3 @@ def test_caching() -> None:
llm_output=None,
)
assert output == expected_output
- @pytest.mark.skipif(
-     importlib.util.find_spec("langchain_community") is None,
-     reason="langchain_community not installed",
- )
- def test_custom_caching() -> None:
-     """Test custom_caching behavior."""
-     Base = declarative_base()
-     class FulltextLLMCache(Base):  # type: ignore
-         """Postgres table for fulltext-indexed LLM Cache."""
-         __tablename__ = "llm_cache_fulltext"
-         id = Column(Integer, Sequence("cache_id"), primary_key=True)
-         prompt = Column(String, nullable=False)
-         llm = Column(String, nullable=False)
-         idx = Column(Integer)
-         response = Column(String)
-     engine = create_engine("sqlite://")
-     from langchain_community.cache import SQLAlchemyCache
-     set_llm_cache(SQLAlchemyCache(engine, FulltextLLMCache))
-     llm = FakeLLM()
-     params = llm.dict()
-     params["stop"] = None
-     llm_string = str(sorted([(k, v) for k, v in params.items()]))
-     get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
-     output = llm.generate(["foo", "bar", "foo"])
-     expected_cache_output = [Generation(text="foo")]
-     cache_output = get_llm_cache().lookup("bar", llm_string)
-     assert cache_output == expected_cache_output
-     set_llm_cache(None)
-     expected_generations = [
-         [Generation(text="fizz")],
-         [Generation(text="foo")],
-         [Generation(text="fizz")],
-     ]
-     expected_output = LLMResult(
-         generations=expected_generations,
-         llm_output=None,
-     )
-     assert output == expected_output

View File

@@ -1,6 +1,3 @@
- import pytest
- from langchain_core.language_models import BaseLLM
from langchain import llms
EXPECT_ALL = [
@@ -91,10 +88,3 @@ EXPECT_ALL = [
def test_all_imports() -> None:
"""Simple test to make sure all things can be imported."""
assert set(llms.__all__) == set(EXPECT_ALL)
- @pytest.mark.community
- def test_all_subclasses() -> None:
-     """Simple test to make sure all things are subclasses of BaseLLM."""
-     for cls in llms.__all__:
-         assert issubclass(getattr(llms, cls), BaseLLM)

View File

@@ -8,12 +8,7 @@ from unittest.mock import patch
import pytest
from langchain_core.load.dump import dumps
from langchain_core.load.serializable import Serializable
- from langchain_core.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
- from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.pydantic_v1 import Field, root_validator
- from langchain_core.tracers.langchain import LangChainTracer
- from langchain.chains.llm import LLMChain
class Person(Serializable):
@@ -74,105 +69,6 @@ def test_typeerror() -> None:
)
- @pytest.mark.community
- @pytest.mark.requires("openai")
- def test_serialize_openai_llm(snapshot: Any) -> None:
-     from langchain_community.llms.openai import OpenAI
-     with patch.dict(os.environ, {"LANGCHAIN_API_KEY": "test-api-key"}):
-         llm = OpenAI(  # type: ignore[call-arg]
-             model="davinci",
-             temperature=0.5,
-             openai_api_key="hello",
-             # This is excluded from serialization
-             callbacks=[LangChainTracer()],
-         )
-         llm.temperature = 0.7  # this is reflected in serialization
-         assert dumps(llm, pretty=True) == snapshot
- @pytest.mark.community
- @pytest.mark.requires("openai")
- def test_serialize_llmchain(snapshot: Any) -> None:
-     from langchain_community.llms.openai import OpenAI
-     llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello")  # type: ignore[call-arg]
-     prompt = PromptTemplate.from_template("hello {name}!")
-     chain = LLMChain(llm=llm, prompt=prompt)
-     assert dumps(chain, pretty=True) == snapshot
- @pytest.mark.community
- @pytest.mark.requires("openai")
- def test_serialize_llmchain_env() -> None:
-     from langchain_community.llms.openai import OpenAI
-     llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello")  # type: ignore[call-arg]
-     prompt = PromptTemplate.from_template("hello {name}!")
-     chain = LLMChain(llm=llm, prompt=prompt)
-     import os
-     has_env = "OPENAI_API_KEY" in os.environ
-     if not has_env:
-         os.environ["OPENAI_API_KEY"] = "env_variable"
-     llm_2 = OpenAI(model="davinci", temperature=0.5)  # type: ignore[call-arg]
-     prompt_2 = PromptTemplate.from_template("hello {name}!")
-     chain_2 = LLMChain(llm=llm_2, prompt=prompt_2)
-     assert dumps(chain_2, pretty=True) == dumps(chain, pretty=True)
-     if not has_env:
-         del os.environ["OPENAI_API_KEY"]
- @pytest.mark.community
- @pytest.mark.requires("openai")
- def test_serialize_llmchain_chat(snapshot: Any) -> None:
-     from langchain_community.chat_models.openai import ChatOpenAI
-     llm = ChatOpenAI(model="davinci", temperature=0.5, openai_api_key="hello")  # type: ignore[call-arg]
-     prompt = ChatPromptTemplate.from_messages(
-         [HumanMessagePromptTemplate.from_template("hello {name}!")]
-     )
-     chain = LLMChain(llm=llm, prompt=prompt)
-     assert dumps(chain, pretty=True) == snapshot
-     import os
-     has_env = "OPENAI_API_KEY" in os.environ
-     if not has_env:
-         os.environ["OPENAI_API_KEY"] = "env_variable"
-     llm_2 = ChatOpenAI(model="davinci", temperature=0.5)  # type: ignore[call-arg]
-     prompt_2 = ChatPromptTemplate.from_messages(
-         [HumanMessagePromptTemplate.from_template("hello {name}!")]
-     )
-     chain_2 = LLMChain(llm=llm_2, prompt=prompt_2)
-     assert dumps(chain_2, pretty=True) == dumps(chain, pretty=True)
-     if not has_env:
-         del os.environ["OPENAI_API_KEY"]
- @pytest.mark.community
- @pytest.mark.requires("openai")
- def test_serialize_llmchain_with_non_serializable_arg(snapshot: Any) -> None:
-     from langchain_community.llms.openai import OpenAI
-     llm = OpenAI(  # type: ignore[call-arg]
-         model="davinci",
-         temperature=0.5,
-         openai_api_key="hello",
-         client=NotSerializable,
-     )
-     prompt = PromptTemplate.from_template("hello {name}!")
-     chain = LLMChain(llm=llm, prompt=prompt)
-     assert dumps(chain, pretty=True) == snapshot
def test_person_with_kwargs(snapshot: Any) -> None:
person = Person(secret="hello")
assert dumps(person, separators=(",", ":")) == snapshot

View File

@@ -7,12 +7,10 @@ from langchain_core.prompts.prompt import PromptTemplate
from langchain.chains.llm import LLMChain
- pytest.importorskip(
-     "langchain_community",
- )
+ pytest.importorskip("langchain_openai", reason="langchain_openai not installed")
+ pytest.importorskip("langchain_community", reason="langchain_community not installed")
- from langchain_community.llms.openai import (  # noqa: E402, # ignore: community-import
+ from langchain_community.llms.openai import (  # noqa: E402 # ignore: community-import
OpenAI as CommunityOpenAI,
)

View File

@@ -1,117 +0,0 @@
import importlib
import inspect
import pkgutil
from types import ModuleType
import pytest
from langchain_core.load.mapping import SERIALIZABLE_MAPPING
def import_all_modules(package_name: str) -> dict:
package = importlib.import_module(package_name)
classes: dict = {}
def _handle_module(module: ModuleType) -> None:
# Iterate over all members of the module
names = dir(module)
if hasattr(module, "__all__"):
names += list(module.__all__)
names = sorted(set(names))
for name in names:
# Check if it's a class or function
attr = getattr(module, name)
if not inspect.isclass(attr):
continue
if not hasattr(attr, "is_lc_serializable") or not isinstance(attr, type):
continue
if (
isinstance(attr.is_lc_serializable(), bool) # type: ignore
and attr.is_lc_serializable() # type: ignore
):
key = tuple(attr.lc_id()) # type: ignore
value = tuple(attr.__module__.split(".") + [attr.__name__])
if key in classes and classes[key] != value:
raise ValueError
classes[key] = value
_handle_module(package)
for importer, modname, ispkg in pkgutil.walk_packages(
package.__path__, package.__name__ + "."
):
try:
module = importlib.import_module(modname)
except ModuleNotFoundError:
continue
_handle_module(module)
return classes
@pytest.mark.community
def test_import_all_modules() -> None:
"""Test import all modules works as expected"""
all_modules = import_all_modules("langchain")
filtered_modules = [
k
for k in all_modules
if len(k) == 4 and tuple(k[:2]) == ("langchain", "chat_models")
]
# This test will need to be updated if new serializable classes are added
# to community
assert sorted(filtered_modules) == sorted(
[
("langchain", "chat_models", "azure_openai", "AzureChatOpenAI"),
("langchain", "chat_models", "bedrock", "BedrockChat"),
("langchain", "chat_models", "anthropic", "ChatAnthropic"),
("langchain", "chat_models", "fireworks", "ChatFireworks"),
("langchain", "chat_models", "google_palm", "ChatGooglePalm"),
("langchain", "chat_models", "openai", "ChatOpenAI"),
("langchain", "chat_models", "vertexai", "ChatVertexAI"),
]
)
@pytest.mark.community
def test_serializable_mapping() -> None:
to_skip = {
# This should have had a different namespace, as it was never
# exported from the langchain module, but we keep for whoever has
# already serialized it.
("langchain", "prompts", "image", "ImagePromptTemplate"): (
"langchain_core",
"prompts",
"image",
"ImagePromptTemplate",
),
# This is not exported from langchain, only langchain_core
("langchain_core", "prompts", "structured", "StructuredPrompt"): (
"langchain_core",
"prompts",
"structured",
"StructuredPrompt",
),
}
serializable_modules = import_all_modules("langchain")
missing = set(SERIALIZABLE_MAPPING).difference(
set(serializable_modules).union(to_skip)
)
assert missing == set()
extra = set(serializable_modules).difference(SERIALIZABLE_MAPPING)
assert extra == set()
for k, import_path in serializable_modules.items():
import_dir, import_obj = import_path[:-1], import_path[-1]
# Import module
mod = importlib.import_module(".".join(import_dir))
# Import class
cls = getattr(mod, import_obj)
assert list(k) == cls.lc_id()
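
For intuition, each key validated above is a class's frozen serialization id (its lc_id()) and each value is the module path where the class is actually defined today; a hypothetical entry pair, paths illustrative:

# lc_id(): historical namespace plus class name, kept stable for compatibility.
key = ("langchain", "chat_models", "openai", "ChatOpenAI")
# Discovered location: the module that defines the class after the migration.
value = ("langchain_community", "chat_models", "openai", "ChatOpenAI")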

View File

@@ -1,89 +0,0 @@
import pytest
from langchain_core.documents import Document
from langchain_core.embeddings import FakeEmbeddings
from langchain.retrievers.ensemble import EnsembleRetriever
pytest.importorskip("langchain_community")
@pytest.mark.requires("rank_bm25")
def test_ensemble_retriever_get_relevant_docs() -> None:
doc_list = [
"I like apples",
"I like oranges",
"Apples and oranges are fruits",
]
from langchain_community.retrievers import BM25Retriever
dummy_retriever = BM25Retriever.from_texts(doc_list)
dummy_retriever.k = 1
ensemble_retriever = EnsembleRetriever( # type: ignore[call-arg]
retrievers=[dummy_retriever, dummy_retriever]
)
docs = ensemble_retriever.invoke("I like apples")
assert len(docs) == 1
@pytest.mark.requires("rank_bm25")
def test_weighted_reciprocal_rank() -> None:
doc1 = Document(page_content="1")
doc2 = Document(page_content="2")
from langchain_community.retrievers import BM25Retriever
dummy_retriever = BM25Retriever.from_texts(["1", "2"])
ensemble_retriever = EnsembleRetriever(
retrievers=[dummy_retriever, dummy_retriever], weights=[0.4, 0.5], c=0
)
result = ensemble_retriever.weighted_reciprocal_rank([[doc1, doc2], [doc2, doc1]])
assert result[0].page_content == "2"
assert result[1].page_content == "1"
ensemble_retriever.weights = [0.5, 0.4]
result = ensemble_retriever.weighted_reciprocal_rank([[doc1, doc2], [doc2, doc1]])
assert result[0].page_content == "1"
assert result[1].page_content == "2"
@pytest.mark.requires("rank_bm25", "sklearn")
def test_ensemble_retriever_get_relevant_docs_with_multiple_retrievers() -> None:
doc_list_a = [
"I like apples",
"I like oranges",
"Apples and oranges are fruits",
]
doc_list_b = [
"I like melons",
"I like pineapples",
"Melons and pineapples are fruits",
]
doc_list_c = [
"I like avocados",
"I like strawberries",
"Avocados and strawberries are fruits",
]
from langchain_community.retrievers import (
BM25Retriever,
KNNRetriever,
TFIDFRetriever,
)
dummy_retriever = BM25Retriever.from_texts(doc_list_a)
dummy_retriever.k = 1
tfidf_retriever = TFIDFRetriever.from_texts(texts=doc_list_b)
tfidf_retriever.k = 1
knn_retriever = KNNRetriever.from_texts(
texts=doc_list_c, embeddings=FakeEmbeddings(size=100)
)
knn_retriever.k = 1
ensemble_retriever = EnsembleRetriever(
retrievers=[dummy_retriever, tfidf_retriever, knn_retriever],
weights=[0.6, 0.3, 0.1],
)
docs = ensemble_retriever.invoke("I like apples")
assert len(docs) == 3
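
The combination step exercised here is weighted Reciprocal Rank Fusion: each document scores the sum over retrievers of weight / (rank + c), with 1-based ranks. A standalone sketch of that formula, not the library implementation:

from collections import defaultdict
from typing import Dict, List

def weighted_rrf(doc_lists: List[List[str]], weights: List[float], c: float = 0) -> List[str]:
    """Fuse ranked lists: score(d) = sum_i weights[i] / (rank_i(d) + c)."""
    scores: Dict[str, float] = defaultdict(float)
    for docs, weight in zip(doc_lists, weights):
        for rank, doc in enumerate(docs, start=1):
            scores[doc] += weight / (rank + c)
    return sorted(scores, key=scores.__getitem__, reverse=True)

# Mirrors test_weighted_reciprocal_rank above: with weights [0.4, 0.5] and c=0,
# "2" scores 0.4/2 + 0.5/1 = 0.7 against 0.65 for "1", so "2" ranks first.
assert weighted_rrf([["1", "2"], ["2", "1"]], [0.4, 0.5]) == ["2", "1"]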

View File

@@ -92,24 +92,3 @@ def test_test_group_dependencies(poetry_conf: Mapping[str, Any]) -> None:
"requests-mock",
]
)
- @pytest.mark.community
- def test_imports() -> None:
-     """Test that you can import all top level things okay."""
-     from langchain_community.callbacks import OpenAICallbackHandler  # noqa: F401
-     from langchain_community.chat_models import ChatOpenAI  # noqa: F401
-     from langchain_community.document_loaders import BSHTMLLoader  # noqa: F401
-     from langchain_community.embeddings import OpenAIEmbeddings  # noqa: F401
-     from langchain_community.llms import OpenAI  # noqa: F401
-     from langchain_community.retrievers import VespaRetriever  # noqa: F401
-     from langchain_community.tools import DuckDuckGoSearchResults  # noqa: F401
-     from langchain_community.utilities import (
-         SearchApiAPIWrapper,  # noqa: F401
-         SerpAPIWrapper,  # noqa: F401
-     )
-     from langchain_community.vectorstores import FAISS  # noqa: F401
-     from langchain_core.prompts import BasePromptTemplate  # noqa: F401
-     from langchain.agents import OpenAIFunctionsAgent  # noqa: F401
-     from langchain.chains import LLMChain  # noqa: F401

View File

@@ -1,13 +1,12 @@
import importlib
from pathlib import Path
import pytest
# Attempt to recursively import all modules in langchain
PKG_ROOT = Path(__file__).parent.parent.parent
+ COMMUNITY_NOT_INSTALLED = importlib.util.find_spec("langchain_community") is None
- @pytest.mark.community
def test_import_all() -> None:
"""Generate the public API for this package."""
library_code = PKG_ROOT / "langchain"
@@ -26,8 +25,13 @@ def test_import_all() -> None:
for name in all:
# Attempt to import the name from the module
- obj = getattr(mod, name)
- assert obj is not None
+ try:
+     obj = getattr(mod, name)
+     assert obj is not None
+ except ModuleNotFoundError as e:
+     # If the module is not installed, we suppress the error
+     if "Module langchain_community" in str(e) and COMMUNITY_NOT_INSTALLED:
+         pass
def test_import_all_using_dir() -> None:
@@ -42,6 +46,9 @@ def test_import_all_using_dir() -> None:
# Without init
module_name = module_name.rsplit(".", 1)[0]
+ if module_name.startswith("langchain_community.") and COMMUNITY_NOT_INSTALLED:
+     continue
try:
mod = importlib.import_module(module_name)
except ModuleNotFoundError as e:

View File

@@ -1,16 +0,0 @@
import pytest
from langchain_core.vectorstores import VectorStore
from langchain import vectorstores
@pytest.mark.community
def test_all_imports() -> None:
"""Simple test to make sure all things can be imported."""
for cls in vectorstores.__all__:
if cls not in [
"AlibabaCloudOpenSearchSettings",
"ClickhouseSettings",
"MyScaleSettings",
]:
assert issubclass(getattr(vectorstores, cls), VectorStore)