langchain[patch],community[minor]: Move some unit tests from langchain to community, use core for fake models (#21190)

This commit is contained in:
Eugene Yurtsev
2024-05-02 09:57:52 -04:00
committed by GitHub
parent c306364b06
commit c9119b0e75
14 changed files with 20 additions and 16 deletions

View File

@@ -177,7 +177,7 @@ def _import_edenai() -> Any:
def _import_fake() -> Any:
from langchain_community.llms.fake import FakeListLLM
from langchain_core.language_models import FakeListLLM
return FakeListLLM

View File

@@ -3,8 +3,8 @@ from typing import TYPE_CHECKING, Any
from langchain._api import create_importer
if TYPE_CHECKING:
from langchain_community.llms import FakeListLLM
from langchain_community.llms.fake import FakeStreamingListLLM
from langchain_core.language_models import FakeListLLM
# Create a way to dynamically look up deprecated imports.
# Used to consolidate logic for raising deprecation warnings and

View File

@@ -1,7 +1,7 @@
from uuid import UUID
import pytest
from langchain_community.llms import FakeListLLM
from langchain_core.language_models import FakeListLLM
from langchain_core.tools import Tool
from langchain.agents import (

View File

@@ -2,9 +2,9 @@
from typing import Union
from langchain_community.llms.fake import FakeListLLM
from langchain_core.agents import AgentAction
from langchain_core.documents import Document
from langchain_core.language_models import FakeListLLM
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.tools import Tool

View File

@@ -1,7 +1,7 @@
from pathlib import Path
from tempfile import TemporaryDirectory
from langchain_community.llms.fake import FakeListLLM
from langchain_core.language_models import FakeListLLM
from langchain_core.tools import Tool
from langchain.agents.agent_types import AgentType

View File

@@ -1,6 +1,6 @@
"""Test conversation chain and memory."""
from langchain_community.llms.fake import FakeListLLM
from langchain_core.documents import Document
from langchain_core.language_models import FakeListLLM
from langchain.chains.conversational_retrieval.base import (
ConversationalRetrievalChain,

View File

@@ -1,5 +1,5 @@
from langchain_community.llms.fake import FakeListLLM
from langchain_core.documents import Document
from langchain_core.language_models import FakeListLLM
from langchain_core.prompts import PromptTemplate
from langchain.chains import create_history_aware_retriever

View File

@@ -1,6 +1,6 @@
"""Test conversation chain and memory."""
from langchain_community.llms.fake import FakeListLLM
from langchain_core.documents import Document
from langchain_core.language_models import FakeListLLM
from langchain_core.prompts.prompt import PromptTemplate
from langchain.chains import create_retrieval_chain

View File

@@ -2,7 +2,7 @@
from typing import List
import pytest
from langchain_community.embeddings.fake import FakeEmbeddings
from langchain_core.embeddings import FakeEmbeddings
from langchain.evaluation.loading import EvaluatorType, load_evaluators
from langchain.evaluation.schema import PairwiseStringEvaluator, StringEvaluator

View File

@@ -1,52 +0,0 @@
import os
import pytest
from pytest_mock import MockerFixture
from langchain.retrievers.document_compressors import CohereRerank
from langchain.schema import Document
os.environ["COHERE_API_KEY"] = "foo"
@pytest.mark.requires("cohere")
def test_init() -> None:
CohereRerank()
CohereRerank(
top_n=5, model="rerank-english_v2.0", cohere_api_key="foo", user_agent="bar"
)
@pytest.mark.requires("cohere")
def test_rerank(mocker: MockerFixture) -> None:
mock_client = mocker.MagicMock()
mock_result = mocker.MagicMock()
mock_result.results = [
mocker.MagicMock(index=0, relevance_score=0.8),
mocker.MagicMock(index=1, relevance_score=0.6),
]
mock_client.rerank.return_value = mock_result
test_documents = [
Document(page_content="This is a test document."),
Document(page_content="Another test document."),
]
test_query = "Test query"
mocker.patch("cohere.Client", return_value=mock_client)
reranker = CohereRerank(cohere_api_key="foo")
results = reranker.rerank(test_documents, test_query)
mock_client.rerank.assert_called_once_with(
query=test_query,
documents=[doc.page_content for doc in test_documents],
model="rerank-english-v2.0",
top_n=3,
max_chunks_per_doc=None,
)
assert results == [
{"index": 0, "relevance_score": 0.8},
{"index": 1, "relevance_score": 0.6},
]

View File

@@ -1,34 +0,0 @@
"""Integration test for CrossEncoderReranker."""
from typing import List
from langchain_community.cross_encoders import FakeCrossEncoder
from langchain_core.documents import Document
from langchain.retrievers.document_compressors import CrossEncoderReranker
def test_rerank() -> None:
texts = [
"aaa1",
"bbb1",
"aaa2",
"bbb2",
"aaa3",
"bbb3",
]
docs = list(map(lambda text: Document(page_content=text), texts))
compressor = CrossEncoderReranker(model=FakeCrossEncoder())
actual_docs = compressor.compress_documents(docs, "bbb2")
actual = list(map(lambda doc: doc.page_content, actual_docs))
expected_returned = ["bbb2", "bbb1", "bbb3"]
expected_not_returned = ["aaa1", "aaa2", "aaa3"]
assert all([text in actual for text in expected_returned])
assert all([text not in actual for text in expected_not_returned])
assert actual[0] == "bbb2"
def test_rerank_empty() -> None:
docs: List[Document] = []
compressor = CrossEncoderReranker(model=FakeCrossEncoder())
actual_docs = compressor.compress_documents(docs, "query")
assert len(actual_docs) == 0

View File

@@ -4,9 +4,8 @@ from typing import Dict, Generator, List, Union
import pytest
from _pytest.fixtures import FixtureRequest
from langchain_community.chat_models import FakeListChatModel
from langchain_community.llms import FakeListLLM
from langchain_core.caches import InMemoryCache
from langchain_core.language_models import FakeListChatModel, FakeListLLM
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.language_models.llms import BaseLLM
from langchain_core.load import dumps