mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-06 19:48:26 +00:00
## Description Contains 2 updates for integration tests to run on langchain's CI. Addendum to #25057 to get the release GitHub Action to succeed.
177 lines
5.1 KiB
Python
177 lines
5.1 KiB
Python
import os
|
|
from time import sleep
|
|
from typing import List
|
|
|
|
import pytest
|
|
from langchain_core.documents import Document
|
|
from langchain_core.embeddings import Embeddings
|
|
from pymongo import MongoClient
|
|
from pymongo.collection import Collection
|
|
|
|
from langchain_mongodb import index
|
|
from langchain_mongodb.retrievers import (
|
|
MongoDBAtlasFullTextSearchRetriever,
|
|
MongoDBAtlasHybridSearchRetriever,
|
|
)
|
|
|
|
from ..utils import ConsistentFakeEmbeddings, PatchedMongoDBAtlasVectorSearch
|
|
|
|
# Connection and namespace used by every test in this module.
CONNECTION_STRING = os.getenv("MONGODB_ATLAS_URI")
DB_NAME = "langchain_test_db"
COLLECTION_NAME = "langchain_test_retrievers"

# Search index names and the document fields they refer to.
VECTOR_INDEX_NAME = "vector_index"
SEARCH_INDEX_NAME = "text_index"
EMBEDDING_FIELD = "embedding"
PAGE_CONTENT_FIELD = "text"

# Embedding size, plus timing values in seconds.
DIMENSIONS = 1536
TIMEOUT = 60.0
INTERVAL = 0.5
|
|
|
|
|
|
@pytest.fixture
def example_documents() -> List[Document]:
    """Provide the four short documents shared by the retriever tests.

    Three of them describe city visits in different years; the fourth is
    an unrelated sentence about sandwiches.
    """
    texts = (
        "In 2023, I visited Paris",
        "In 2022, I visited New York",
        "In 2021, I visited New Orleans",
        "Sandwiches are beautiful. Sandwiches are fine.",
    )
    return [Document(page_content=text) for text in texts]
|
|
|
|
|
|
@pytest.fixture
def embedding_openai() -> Embeddings:
    """Provide an embedding model for the vector-based tests.

    Tries to build ``OpenAIEmbeddings`` from the ``OPENAI_API_KEY``
    environment variable; if that raises for any reason (including the
    variable being unset), falls back to
    ``ConsistentFakeEmbeddings(DIMENSIONS)``.
    """
    from langchain_openai import OpenAIEmbeddings

    embedder: Embeddings
    try:
        api_key = os.environ["OPENAI_API_KEY"]
        embedder = OpenAIEmbeddings(
            openai_api_key=api_key,  # type: ignore # noqa
            model="text-embedding-3-small",
        )
    except Exception:
        # Deliberate best-effort fallback so the tests can run without OpenAI.
        embedder = ConsistentFakeEmbeddings(DIMENSIONS)
    return embedder
|
|
|
|
|
|
@pytest.fixture
def collection() -> Collection:
    """Return an emptied test collection with a vector and a full-text index.

    Connects with ``CONNECTION_STRING``, creates ``COLLECTION_NAME`` inside
    ``DB_NAME`` when it is not listed yet, deletes every document in it, and
    then creates whichever of the two search indexes
    (``VECTOR_INDEX_NAME``, ``SEARCH_INDEX_NAME``) is missing. Each index
    creation call is passed ``wait_until_complete=TIMEOUT``.

    Returns:
        The collection, empty of documents.
    """
    client: MongoClient = MongoClient(CONNECTION_STRING)
    database = client[DB_NAME]

    if COLLECTION_NAME not in database.list_collection_names():
        clxn = database.create_collection(COLLECTION_NAME)
    else:
        clxn = database[COLLECTION_NAME]

    # Start every test from an empty collection.
    clxn.delete_many({})

    # A single listing serves both checks: creating the vector index below
    # cannot add SEARCH_INDEX_NAME to the set of existing names.
    existing_indexes = {ix["name"] for ix in clxn.list_search_indexes()}

    if VECTOR_INDEX_NAME not in existing_indexes:
        index.create_vector_search_index(
            collection=clxn,
            index_name=VECTOR_INDEX_NAME,
            dimensions=DIMENSIONS,
            path=EMBEDDING_FIELD,
            similarity="cosine",
            wait_until_complete=TIMEOUT,
        )

    if SEARCH_INDEX_NAME not in existing_indexes:
        index.create_fulltext_search_index(
            collection=clxn,
            index_name=SEARCH_INDEX_NAME,
            field=PAGE_CONTENT_FIELD,
            wait_until_complete=TIMEOUT,
        )

    return clxn
|
|
|
|
|
|
def test_hybrid_retriever(
    embedding_openai: Embeddings,
    collection: Collection,
    example_documents: List[Document],
) -> None:
    """Exercise basic usage of MongoDBAtlasHybridSearchRetriever."""
    store = PatchedMongoDBAtlasVectorSearch(
        collection=collection,
        embedding=embedding_openai,
        index_name=VECTOR_INDEX_NAME,
        text_key=PAGE_CONTENT_FIELD,
    )
    store.add_documents(example_documents)

    # Fixed pause: wait for the documents to be sync'd.
    sleep(TIMEOUT)

    hybrid = MongoDBAtlasHybridSearchRetriever(
        vectorstore=store,
        search_index_name=SEARCH_INDEX_NAME,
        top_k=3,
    )

    city_hits = hybrid.invoke("What was the latest city that I visited?")
    assert len(city_hits) == 3
    assert "Paris" in city_hits[0].page_content

    orleans_hits = hybrid.invoke("When was the last time I visited new orleans?")
    assert "New Orleans" in orleans_hits[0].page_content
|
|
|
|
|
|
def test_fulltext_retriever(
    collection: Collection,
    example_documents: List[Document],
) -> None:
    """Check the result of a full-text search.

    The documents are inserted through MongoDB's Collection API,
    independently of any VectorStore.
    """
    payload = [{PAGE_CONTENT_FIELD: doc.page_content} for doc in example_documents]
    collection.insert_many(payload)

    # Fixed pause: wait for the documents to be sync'd.
    sleep(TIMEOUT)

    fulltext = MongoDBAtlasFullTextSearchRetriever(
        collection=collection,
        search_index_name=SEARCH_INDEX_NAME,
        search_field=PAGE_CONTENT_FIELD,
    )

    top_hit = fulltext.invoke("When was the last time I visited new orleans?")[0]
    assert "New Orleans" in top_hit.page_content
    assert "score" in top_hit.metadata
|
|
|
|
|
|
def test_vector_retriever(
    embedding_openai: Embeddings,
    collection: Collection,
    example_documents: List[Document],
) -> None:
    """Exercise the retriever returned by the vector store's as_retriever()."""
    store = PatchedMongoDBAtlasVectorSearch(
        collection=collection,
        embedding=embedding_openai,
        index_name=VECTOR_INDEX_NAME,
        text_key=PAGE_CONTENT_FIELD,
    )
    store.add_documents(example_documents)

    # Fixed pause: wait for the documents to be sync'd.
    sleep(TIMEOUT)

    vector_retriever = store.as_retriever()

    city_hits = vector_retriever.invoke("What was the latest city that I visited?")
    assert len(city_hits) == 4
    assert "Paris" in city_hits[0].page_content

    orleans_hits = vector_retriever.invoke(
        "When was the last time I visited new orleans?"
    )
    assert "New Orleans" in orleans_hits[0].page_content
|