mirror of https://github.com/hwchase17/langchain.git
synced 2025-09-17 23:41:46 +00:00
langchain[patch],community[patch]: Move unit tests that depend on community to community (#21685)
@@ -23,7 +23,7 @@ from langchain._api import create_importer
 
 _module_lookup = {
     "APIChain": "langchain.chains.api.base",
-    "OpenAPIEndpointChain": "langchain.chains.api.openapi.chain",
+    "OpenAPIEndpointChain": "langchain_community.chains.openapi.chain",
     "AnalyzeDocumentChain": "langchain.chains.combine_documents.base",
     "MapReduceDocumentsChain": "langchain.chains.combine_documents.map_reduce",
     "MapRerankDocumentsChain": "langchain.chains.combine_documents.map_rerank",
@@ -36,23 +36,23 @@ _module_lookup = {
     "ConversationalRetrievalChain": "langchain.chains.conversational_retrieval.base",
     "generate_example": "langchain.chains.example_generator",
     "FlareChain": "langchain.chains.flare.base",
-    "ArangoGraphQAChain": "langchain.chains.graph_qa.arangodb",
-    "GraphQAChain": "langchain.chains.graph_qa.base",
-    "GraphCypherQAChain": "langchain.chains.graph_qa.cypher",
-    "FalkorDBQAChain": "langchain.chains.graph_qa.falkordb",
-    "HugeGraphQAChain": "langchain.chains.graph_qa.hugegraph",
-    "KuzuQAChain": "langchain.chains.graph_qa.kuzu",
-    "NebulaGraphQAChain": "langchain.chains.graph_qa.nebulagraph",
-    "NeptuneOpenCypherQAChain": "langchain.chains.graph_qa.neptune_cypher",
-    "NeptuneSparqlQAChain": "langchain.chains.graph_qa.neptune_sparql",
-    "OntotextGraphDBQAChain": "langchain.chains.graph_qa.ontotext_graphdb",
-    "GraphSparqlQAChain": "langchain.chains.graph_qa.sparql",
+    "ArangoGraphQAChain": "langchain_community.chains.graph_qa.arangodb",
+    "GraphQAChain": "langchain_community.chains.graph_qa.base",
+    "GraphCypherQAChain": "langchain_community.chains.graph_qa.cypher",
+    "FalkorDBQAChain": "langchain_community.chains.graph_qa.falkordb",
+    "HugeGraphQAChain": "langchain_community.chains.graph_qa.hugegraph",
+    "KuzuQAChain": "langchain_community.chains.graph_qa.kuzu",
+    "NebulaGraphQAChain": "langchain_community.chains.graph_qa.nebulagraph",
+    "NeptuneOpenCypherQAChain": "langchain_community.chains.graph_qa.neptune_cypher",
+    "NeptuneSparqlQAChain": "langchain_community.chains.graph_qa.neptune_sparql",
+    "OntotextGraphDBQAChain": "langchain_community.chains.graph_qa.ontotext_graphdb",
+    "GraphSparqlQAChain": "langchain_community.chains.graph_qa.sparql",
     "create_history_aware_retriever": "langchain.chains.history_aware_retriever",
     "HypotheticalDocumentEmbedder": "langchain.chains.hyde.base",
     "LLMChain": "langchain.chains.llm",
     "LLMCheckerChain": "langchain.chains.llm_checker.base",
     "LLMMathChain": "langchain.chains.llm_math.base",
-    "LLMRequestsChain": "langchain.chains.llm_requests",
+    "LLMRequestsChain": "langchain_community.chains.llm_requests",
     "LLMSummarizationCheckerChain": "langchain.chains.llm_summarization_checker.base",
     "load_chain": "langchain.chains.loading",
     "MapReduceChain": "langchain.chains.mapreduce",
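Note: the `_module_lookup` entries above are consumed by `create_importer` (visible in the hunk header) to resolve attributes lazily, so `from langchain.chains import OpenAPIEndpointChain` keeps working after the implementation moves to `langchain_community`. A minimal sketch of the general pattern (PEP 562 module-level `__getattr__`; the one-entry table is illustrative, not the full lookup):

```python
import importlib
from typing import Any

# Illustrative one-entry table in the style of _module_lookup above.
_module_lookup = {
    "OpenAPIEndpointChain": "langchain_community.chains.openapi.chain",
}


def __getattr__(name: str) -> Any:
    # Called only when `name` is not found in the module: import the
    # symbol's real home on first access and hand the attribute back.
    if name in _module_lookup:
        module = importlib.import_module(_module_lookup[name])
        return getattr(module, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```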
@@ -343,7 +343,6 @@ addopts = "--strict-markers --strict-config --durations=5 --snapshot-warn-unused
 markers = [
     "requires: mark tests as requiring a specific library",
     "scheduled: mark tests to run in scheduled testing",
-    "community: mark tests that require langchain-community to be installed",
    "compile: mark placeholder test used to compile integration tests without running them"
 ]
 asyncio_mode = "auto"
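Note: the `community` marker removed above is what made `@pytest.mark.community` legal under `--strict-markers`; selection then happens on the command line (`pytest -m community` to run only those tests, `pytest -m "not community"` to exclude them). A sketch of how a test was presumably gated while the marker existed:

```python
import pytest


@pytest.mark.community  # legal only while "community" is declared in pyproject.toml
def test_something_that_needs_community() -> None:
    # importorskip degrades gracefully when the package is absent.
    community = pytest.importorskip("langchain_community")
    assert community is not None
```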
@@ -1,30 +0,0 @@
-from pathlib import Path
-from tempfile import TemporaryDirectory
-
-import pytest
-from langchain_core.language_models import FakeListLLM
-from langchain_core.tools import Tool
-
-from langchain.agents.agent_types import AgentType
-from langchain.agents.initialize import initialize_agent, load_agent
-
-pytest.importorskip("langchain_community")
-
-
-def test_mrkl_serialization() -> None:
-    agent = initialize_agent(
-        [
-            Tool(
-                name="Test tool",
-                func=lambda x: x,
-                description="Test description",
-            )
-        ],
-        FakeListLLM(responses=[]),
-        agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
-        verbose=True,
-    )
-    with TemporaryDirectory() as tempdir:
-        file = Path(tempdir) / "agent.json"
-        agent.save_agent(file)
-        load_agent(file)
@@ -1,61 +0,0 @@
-"""Test the loading function for evaluators."""
-from typing import List
-
-import pytest
-from langchain_core.embeddings import FakeEmbeddings
-
-from langchain.evaluation.loading import EvaluatorType, load_evaluators
-from langchain.evaluation.schema import PairwiseStringEvaluator, StringEvaluator
-from tests.unit_tests.llms.fake_chat_model import FakeChatModel
-from tests.unit_tests.llms.fake_llm import FakeLLM
-
-
-@pytest.mark.requires("rapidfuzz")
-@pytest.mark.parametrize("evaluator_type", EvaluatorType)
-def test_load_evaluators(evaluator_type: EvaluatorType) -> None:
-    """Test loading evaluators."""
-    fake_llm = FakeChatModel()
-    embeddings = FakeEmbeddings(size=32)
-    load_evaluators([evaluator_type], llm=fake_llm, embeddings=embeddings)
-
-    # Test as string
-    load_evaluators(
-        [evaluator_type.value],  # type: ignore
-        llm=fake_llm,
-        embeddings=embeddings,
-    )
-
-
-@pytest.mark.community
-@pytest.mark.parametrize(
-    "evaluator_types",
-    [
-        [EvaluatorType.LABELED_CRITERIA],
-        [EvaluatorType.LABELED_PAIRWISE_STRING],
-        [EvaluatorType.LABELED_SCORE_STRING],
-        [EvaluatorType.QA],
-        [EvaluatorType.CONTEXT_QA],
-        [EvaluatorType.COT_QA],
-        [EvaluatorType.COT_QA, EvaluatorType.LABELED_CRITERIA],
-        [
-            EvaluatorType.COT_QA,
-            EvaluatorType.LABELED_CRITERIA,
-            EvaluatorType.LABELED_PAIRWISE_STRING,
-        ],
-        [EvaluatorType.JSON_EQUALITY],
-        [EvaluatorType.EXACT_MATCH, EvaluatorType.REGEX_MATCH],
-    ],
-)
-def test_eval_chain_requires_references(evaluator_types: List[EvaluatorType]) -> None:
-    """Test loading evaluators."""
-    fake_llm = FakeLLM(
-        queries={"text": "The meaning of life\nCORRECT"}, sequential_responses=True
-    )
-    evaluators = load_evaluators(
-        evaluator_types,
-        llm=fake_llm,
-    )
-    for evaluator in evaluators:
-        if not isinstance(evaluator, (StringEvaluator, PairwiseStringEvaluator)):
-            raise ValueError("Evaluator is not a [pairwise]string evaluator")
-        assert evaluator.requires_reference
@@ -1,14 +1,4 @@
 """Test base LLM functionality."""
-import importlib
-
-import pytest
-from sqlalchemy import Column, Integer, Sequence, String, create_engine
-
-try:
-    from sqlalchemy.orm import declarative_base
-except ImportError:
-    from sqlalchemy.ext.declarative import declarative_base
-
 
 from langchain_core.caches import InMemoryCache
 from langchain_core.outputs import Generation, LLMResult
@@ -50,48 +40,3 @@ def test_caching() -> None:
         llm_output=None,
     )
     assert output == expected_output
-
-
-@pytest.mark.skipif(
-    importlib.util.find_spec("langchain_community") is None,
-    reason="langchain_community not installed",
-)
-def test_custom_caching() -> None:
-    """Test custom_caching behavior."""
-    Base = declarative_base()
-
-    class FulltextLLMCache(Base):  # type: ignore
-        """Postgres table for fulltext-indexed LLM Cache."""
-
-        __tablename__ = "llm_cache_fulltext"
-        id = Column(Integer, Sequence("cache_id"), primary_key=True)
-        prompt = Column(String, nullable=False)
-        llm = Column(String, nullable=False)
-        idx = Column(Integer)
-        response = Column(String)
-
-    engine = create_engine("sqlite://")
-
-    from langchain_community.cache import SQLAlchemyCache
-
-    set_llm_cache(SQLAlchemyCache(engine, FulltextLLMCache))
-    llm = FakeLLM()
-    params = llm.dict()
-    params["stop"] = None
-    llm_string = str(sorted([(k, v) for k, v in params.items()]))
-    get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
-    output = llm.generate(["foo", "bar", "foo"])
-    expected_cache_output = [Generation(text="foo")]
-    cache_output = get_llm_cache().lookup("bar", llm_string)
-    assert cache_output == expected_cache_output
-    set_llm_cache(None)
-    expected_generations = [
-        [Generation(text="fizz")],
-        [Generation(text="foo")],
-        [Generation(text="fizz")],
-    ]
-    expected_output = LLMResult(
-        generations=expected_generations,
-        llm_output=None,
-    )
-    assert output == expected_output
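Note: after this change only the in-memory caching path stays in core; the SQLAlchemy-backed variant moves out with `langchain_community.cache`. A minimal sketch of the cache flow the surviving `test_caching` exercises (the `langchain.globals` import path for the cache accessors is an assumption for this era of the codebase):

```python
from langchain_core.caches import InMemoryCache
from langchain_core.outputs import Generation

from langchain.globals import get_llm_cache, set_llm_cache  # assumed path

set_llm_cache(InMemoryCache())
# A cache is keyed by (prompt, llm_string); prime it, then read it back.
get_llm_cache().update("foo", "fake-llm-params", [Generation(text="fizz")])
assert get_llm_cache().lookup("foo", "fake-llm-params") == [Generation(text="fizz")]
set_llm_cache(None)  # reset global state so other tests are unaffected
```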
@@ -1,6 +1,3 @@
-import pytest
-from langchain_core.language_models import BaseLLM
-
 from langchain import llms
 
 EXPECT_ALL = [
@@ -91,10 +88,3 @@ EXPECT_ALL = [
 def test_all_imports() -> None:
     """Simple test to make sure all things can be imported."""
     assert set(llms.__all__) == set(EXPECT_ALL)
-
-
-@pytest.mark.community
-def test_all_subclasses() -> None:
-    """Simple test to make sure all things are subclasses of BaseLLM."""
-    for cls in llms.__all__:
-        assert issubclass(getattr(llms, cls), BaseLLM)
@@ -8,12 +8,7 @@ from unittest.mock import patch
 import pytest
 from langchain_core.load.dump import dumps
 from langchain_core.load.serializable import Serializable
-from langchain_core.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
-from langchain_core.prompts.prompt import PromptTemplate
 from langchain_core.pydantic_v1 import Field, root_validator
-from langchain_core.tracers.langchain import LangChainTracer
-
-from langchain.chains.llm import LLMChain
 
 
 class Person(Serializable):
@@ -74,105 +69,6 @@ def test_typeerror() -> None:
     )
 
 
-@pytest.mark.community
-@pytest.mark.requires("openai")
-def test_serialize_openai_llm(snapshot: Any) -> None:
-    from langchain_community.llms.openai import OpenAI
-
-    with patch.dict(os.environ, {"LANGCHAIN_API_KEY": "test-api-key"}):
-        llm = OpenAI(  # type: ignore[call-arg]
-            model="davinci",
-            temperature=0.5,
-            openai_api_key="hello",
-            # This is excluded from serialization
-            callbacks=[LangChainTracer()],
-        )
-        llm.temperature = 0.7  # this is reflected in serialization
-        assert dumps(llm, pretty=True) == snapshot
-
-
-@pytest.mark.community
-@pytest.mark.requires("openai")
-def test_serialize_llmchain(snapshot: Any) -> None:
-    from langchain_community.llms.openai import OpenAI
-
-    llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello")  # type: ignore[call-arg]
-    prompt = PromptTemplate.from_template("hello {name}!")
-    chain = LLMChain(llm=llm, prompt=prompt)
-    assert dumps(chain, pretty=True) == snapshot
-
-
-@pytest.mark.community
-@pytest.mark.requires("openai")
-def test_serialize_llmchain_env() -> None:
-    from langchain_community.llms.openai import OpenAI
-
-    llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello")  # type: ignore[call-arg]
-    prompt = PromptTemplate.from_template("hello {name}!")
-    chain = LLMChain(llm=llm, prompt=prompt)
-
-    import os
-
-    has_env = "OPENAI_API_KEY" in os.environ
-    if not has_env:
-        os.environ["OPENAI_API_KEY"] = "env_variable"
-
-    llm_2 = OpenAI(model="davinci", temperature=0.5)  # type: ignore[call-arg]
-    prompt_2 = PromptTemplate.from_template("hello {name}!")
-    chain_2 = LLMChain(llm=llm_2, prompt=prompt_2)
-
-    assert dumps(chain_2, pretty=True) == dumps(chain, pretty=True)
-
-    if not has_env:
-        del os.environ["OPENAI_API_KEY"]
-
-
-@pytest.mark.community
-@pytest.mark.requires("openai")
-def test_serialize_llmchain_chat(snapshot: Any) -> None:
-    from langchain_community.chat_models.openai import ChatOpenAI
-
-    llm = ChatOpenAI(model="davinci", temperature=0.5, openai_api_key="hello")  # type: ignore[call-arg]
-    prompt = ChatPromptTemplate.from_messages(
-        [HumanMessagePromptTemplate.from_template("hello {name}!")]
-    )
-    chain = LLMChain(llm=llm, prompt=prompt)
-    assert dumps(chain, pretty=True) == snapshot
-
-    import os
-
-    has_env = "OPENAI_API_KEY" in os.environ
-    if not has_env:
-        os.environ["OPENAI_API_KEY"] = "env_variable"
-
-    llm_2 = ChatOpenAI(model="davinci", temperature=0.5)  # type: ignore[call-arg]
-    prompt_2 = ChatPromptTemplate.from_messages(
-        [HumanMessagePromptTemplate.from_template("hello {name}!")]
-    )
-    chain_2 = LLMChain(llm=llm_2, prompt=prompt_2)
-
-    assert dumps(chain_2, pretty=True) == dumps(chain, pretty=True)
-
-    if not has_env:
-        del os.environ["OPENAI_API_KEY"]
-
-
-@pytest.mark.community
-@pytest.mark.requires("openai")
-def test_serialize_llmchain_with_non_serializable_arg(snapshot: Any) -> None:
-    from langchain_community.llms.openai import OpenAI
-
-    llm = OpenAI(  # type: ignore[call-arg]
-        model="davinci",
-        temperature=0.5,
-        openai_api_key="hello",
-        client=NotSerializable,
-    )
-    prompt = PromptTemplate.from_template("hello {name}!")
-    chain = LLMChain(llm=llm, prompt=prompt)
-    assert dumps(chain, pretty=True) == snapshot
-
-
 def test_person_with_kwargs(snapshot: Any) -> None:
     person = Person(secret="hello")
     assert dumps(person, separators=(",", ":")) == snapshot
@@ -7,12 +7,10 @@ from langchain_core.prompts.prompt import PromptTemplate
 
 from langchain.chains.llm import LLMChain
 
-pytest.importorskip(
-    "langchain_community",
-)
 pytest.importorskip("langchain_openai", reason="langchain_openai not installed")
+pytest.importorskip("langchain_community", reason="langchain_community not installed")
 
 
-from langchain_community.llms.openai import (  # noqa: E402, # ignore: community-import
+from langchain_community.llms.openai import (  # noqa: E402 # ignore: community-import
     OpenAI as CommunityOpenAI,
 )
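Note: the rewrite above folds a multi-line `pytest.importorskip` call into one-liners with explicit `reason` strings. `importorskip` returns the imported module, or skips everything below it at collection time when the import fails, e.g.:

```python
import pytest

# Skips the rest of the file at collection time if the package is missing.
openai_mod = pytest.importorskip("langchain_openai", reason="langchain_openai not installed")
```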
@@ -1,117 +0,0 @@
-import importlib
-import inspect
-import pkgutil
-from types import ModuleType
-
-import pytest
-from langchain_core.load.mapping import SERIALIZABLE_MAPPING
-
-
-def import_all_modules(package_name: str) -> dict:
-    package = importlib.import_module(package_name)
-    classes: dict = {}
-
-    def _handle_module(module: ModuleType) -> None:
-        # Iterate over all members of the module
-
-        names = dir(module)
-
-        if hasattr(module, "__all__"):
-            names += list(module.__all__)
-
-        names = sorted(set(names))
-
-        for name in names:
-            # Check if it's a class or function
-            attr = getattr(module, name)
-
-            if not inspect.isclass(attr):
-                continue
-
-            if not hasattr(attr, "is_lc_serializable") or not isinstance(attr, type):
-                continue
-
-            if (
-                isinstance(attr.is_lc_serializable(), bool)  # type: ignore
-                and attr.is_lc_serializable()  # type: ignore
-            ):
-                key = tuple(attr.lc_id())  # type: ignore
-                value = tuple(attr.__module__.split(".") + [attr.__name__])
-                if key in classes and classes[key] != value:
-                    raise ValueError
-                classes[key] = value
-
-    _handle_module(package)
-
-    for importer, modname, ispkg in pkgutil.walk_packages(
-        package.__path__, package.__name__ + "."
-    ):
-        try:
-            module = importlib.import_module(modname)
-        except ModuleNotFoundError:
-            continue
-        _handle_module(module)
-
-    return classes
-
-
-@pytest.mark.community
-def test_import_all_modules() -> None:
-    """Test import all modules works as expected"""
-    all_modules = import_all_modules("langchain")
-    filtered_modules = [
-        k
-        for k in all_modules
-        if len(k) == 4 and tuple(k[:2]) == ("langchain", "chat_models")
-    ]
-    # This test will need to be updated if new serializable classes are added
-    # to community
-    assert sorted(filtered_modules) == sorted(
-        [
-            ("langchain", "chat_models", "azure_openai", "AzureChatOpenAI"),
-            ("langchain", "chat_models", "bedrock", "BedrockChat"),
-            ("langchain", "chat_models", "anthropic", "ChatAnthropic"),
-            ("langchain", "chat_models", "fireworks", "ChatFireworks"),
-            ("langchain", "chat_models", "google_palm", "ChatGooglePalm"),
-            ("langchain", "chat_models", "openai", "ChatOpenAI"),
-            ("langchain", "chat_models", "vertexai", "ChatVertexAI"),
-        ]
-    )
-
-
-@pytest.mark.community
-def test_serializable_mapping() -> None:
-    to_skip = {
-        # This should have had a different namespace, as it was never
-        # exported from the langchain module, but we keep for whoever has
-        # already serialized it.
-        ("langchain", "prompts", "image", "ImagePromptTemplate"): (
-            "langchain_core",
-            "prompts",
-            "image",
-            "ImagePromptTemplate",
-        ),
-        # This is not exported from langchain, only langchain_core
-        ("langchain_core", "prompts", "structured", "StructuredPrompt"): (
-            "langchain_core",
-            "prompts",
-            "structured",
-            "StructuredPrompt",
-        ),
-    }
-    serializable_modules = import_all_modules("langchain")
-
-    missing = set(SERIALIZABLE_MAPPING).difference(
-        set(serializable_modules).union(to_skip)
-    )
-    assert missing == set()
-    extra = set(serializable_modules).difference(SERIALIZABLE_MAPPING)
-    assert extra == set()
-
-    for k, import_path in serializable_modules.items():
-        import_dir, import_obj = import_path[:-1], import_path[-1]
-        # Import module
-        mod = importlib.import_module(".".join(import_dir))
-        # Import class
-        cls = getattr(mod, import_obj)
-        assert list(k) == cls.lc_id()
@@ -1,89 +0,0 @@
-import pytest
-from langchain_core.documents import Document
-from langchain_core.embeddings import FakeEmbeddings
-
-from langchain.retrievers.ensemble import EnsembleRetriever
-
-pytest.importorskip("langchain_community")
-
-
-@pytest.mark.requires("rank_bm25")
-def test_ensemble_retriever_get_relevant_docs() -> None:
-    doc_list = [
-        "I like apples",
-        "I like oranges",
-        "Apples and oranges are fruits",
-    ]
-
-    from langchain_community.retrievers import BM25Retriever
-
-    dummy_retriever = BM25Retriever.from_texts(doc_list)
-    dummy_retriever.k = 1
-
-    ensemble_retriever = EnsembleRetriever(  # type: ignore[call-arg]
-        retrievers=[dummy_retriever, dummy_retriever]
-    )
-    docs = ensemble_retriever.invoke("I like apples")
-    assert len(docs) == 1
-
-
-@pytest.mark.requires("rank_bm25")
-def test_weighted_reciprocal_rank() -> None:
-    doc1 = Document(page_content="1")
-    doc2 = Document(page_content="2")
-
-    from langchain_community.retrievers import BM25Retriever
-
-    dummy_retriever = BM25Retriever.from_texts(["1", "2"])
-    ensemble_retriever = EnsembleRetriever(
-        retrievers=[dummy_retriever, dummy_retriever], weights=[0.4, 0.5], c=0
-    )
-    result = ensemble_retriever.weighted_reciprocal_rank([[doc1, doc2], [doc2, doc1]])
-    assert result[0].page_content == "2"
-    assert result[1].page_content == "1"
-
-    ensemble_retriever.weights = [0.5, 0.4]
-    result = ensemble_retriever.weighted_reciprocal_rank([[doc1, doc2], [doc2, doc1]])
-    assert result[0].page_content == "1"
-    assert result[1].page_content == "2"
-
-
-@pytest.mark.requires("rank_bm25", "sklearn")
-def test_ensemble_retriever_get_relevant_docs_with_multiple_retrievers() -> None:
-    doc_list_a = [
-        "I like apples",
-        "I like oranges",
-        "Apples and oranges are fruits",
-    ]
-    doc_list_b = [
-        "I like melons",
-        "I like pineapples",
-        "Melons and pineapples are fruits",
-    ]
-    doc_list_c = [
-        "I like avocados",
-        "I like strawberries",
-        "Avocados and strawberries are fruits",
-    ]
-
-    from langchain_community.retrievers import (
-        BM25Retriever,
-        KNNRetriever,
-        TFIDFRetriever,
-    )
-
-    dummy_retriever = BM25Retriever.from_texts(doc_list_a)
-    dummy_retriever.k = 1
-    tfidf_retriever = TFIDFRetriever.from_texts(texts=doc_list_b)
-    tfidf_retriever.k = 1
-    knn_retriever = KNNRetriever.from_texts(
-        texts=doc_list_c, embeddings=FakeEmbeddings(size=100)
-    )
-    knn_retriever.k = 1
-
-    ensemble_retriever = EnsembleRetriever(
-        retrievers=[dummy_retriever, tfidf_retriever, knn_retriever],
-        weights=[0.6, 0.3, 0.1],
-    )
-    docs = ensemble_retriever.invoke("I like apples")
-    assert len(docs) == 3
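Note: the deleted `test_weighted_reciprocal_rank` pins down the fusion rule EnsembleRetriever uses: each document scores the sum over retrievers of `weight / (rank + c)`. A standalone sketch of that computation (my own illustration, not the library code):

```python
from collections import defaultdict
from typing import Dict, List


def weighted_rrf(doc_lists: List[List[str]], weights: List[float], c: float = 60) -> List[str]:
    """Weighted reciprocal rank fusion: score(d) = sum_i weights[i] / (rank_i(d) + c)."""
    scores: Dict[str, float] = defaultdict(float)
    for docs, weight in zip(doc_lists, weights):
        for rank, doc in enumerate(docs, start=1):  # ranks are 1-based
            scores[doc] += weight / (rank + c)
    return sorted(scores, key=scores.__getitem__, reverse=True)


# With c=0 and weights [0.4, 0.5], "2" outscores "1" (0.7 vs. 0.65),
# matching the first assertion in the deleted test.
assert weighted_rrf([["1", "2"], ["2", "1"]], [0.4, 0.5], c=0)[0] == "2"
```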
@@ -92,24 +92,3 @@ def test_test_group_dependencies(poetry_conf: Mapping[str, Any]) -> None:
             "requests-mock",
         ]
     )
-
-
-@pytest.mark.community
-def test_imports() -> None:
-    """Test that you can import all top level things okay."""
-    from langchain_community.callbacks import OpenAICallbackHandler  # noqa: F401
-    from langchain_community.chat_models import ChatOpenAI  # noqa: F401
-    from langchain_community.document_loaders import BSHTMLLoader  # noqa: F401
-    from langchain_community.embeddings import OpenAIEmbeddings  # noqa: F401
-    from langchain_community.llms import OpenAI  # noqa: F401
-    from langchain_community.retrievers import VespaRetriever  # noqa: F401
-    from langchain_community.tools import DuckDuckGoSearchResults  # noqa: F401
-    from langchain_community.utilities import (
-        SearchApiAPIWrapper,  # noqa: F401
-        SerpAPIWrapper,  # noqa: F401
-    )
-    from langchain_community.vectorstores import FAISS  # noqa: F401
-    from langchain_core.prompts import BasePromptTemplate  # noqa: F401
-
-    from langchain.agents import OpenAIFunctionsAgent  # noqa: F401
-    from langchain.chains import LLMChain  # noqa: F401
@@ -1,13 +1,12 @@
 import importlib
 from pathlib import Path
 
 import pytest
 
 # Attempt to recursively import all modules in langchain
 PKG_ROOT = Path(__file__).parent.parent.parent
 
+COMMUNITY_NOT_INSTALLED = importlib.util.find_spec("langchain_community") is None
-
-@pytest.mark.community
 def test_import_all() -> None:
     """Generate the public API for this package."""
     library_code = PKG_ROOT / "langchain"
@@ -26,8 +25,13 @@ def test_import_all() -> None:
 
         for name in all:
             # Attempt to import the name from the module
-            obj = getattr(mod, name)
-            assert obj is not None
+            try:
+                obj = getattr(mod, name)
+                assert obj is not None
+            except ModuleNotFoundError as e:
+                # If the module is not installed, we suppress the error
+                if "Module langchain_community" in str(e) and COMMUNITY_NOT_INSTALLED:
+                    pass
 
 
 def test_import_all_using_dir() -> None:
@@ -42,6 +46,9 @@ def test_import_all_using_dir() -> None:
         # Without init
         module_name = module_name.rsplit(".", 1)[0]
 
+        if module_name.startswith("langchain_community.") and COMMUNITY_NOT_INSTALLED:
+            continue
+
        try:
            mod = importlib.import_module(module_name)
        except ModuleNotFoundError as e:
@@ -1,16 +0,0 @@
-import pytest
-from langchain_core.vectorstores import VectorStore
-
-from langchain import vectorstores
-
-
-@pytest.mark.community
-def test_all_imports() -> None:
-    """Simple test to make sure all things can be imported."""
-    for cls in vectorstores.__all__:
-        if cls not in [
-            "AlibabaCloudOpenSearchSettings",
-            "ClickhouseSettings",
-            "MyScaleSettings",
-        ]:
-            assert issubclass(getattr(vectorstores, cls), VectorStore)