mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-04 08:10:25 +00:00
Compare commits
17 Commits
erick/stan
...
cc/retriev
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
166ab8217f | ||
|
|
46caee5680 | ||
|
|
5fb7d3b4ba | ||
|
|
402298e376 | ||
|
|
267ee9db4c | ||
|
|
bc7af5fd7e | ||
|
|
abf1f4c124 | ||
|
|
c9fc0447ec | ||
|
|
c262cef1fb | ||
|
|
26455d156d | ||
|
|
ceea324071 | ||
|
|
1544c9d050 | ||
|
|
7a78068cd4 | ||
|
|
d91fd8cdcb | ||
|
|
1f15b0885d | ||
|
|
4f16714195 | ||
|
|
f8598a7e48 |
@@ -1444,7 +1444,11 @@ class RedisVectorStoreRetriever(VectorStoreRetriever):
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def _get_relevant_documents(
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
include_score: bool = False,
|
||||
) -> List[Document]:
|
||||
if self.search_type == "similarity":
|
||||
docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
|
||||
@@ -1472,7 +1476,11 @@ class RedisVectorStoreRetriever(VectorStoreRetriever):
|
||||
return docs
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
include_score: bool = False,
|
||||
) -> List[Document]:
|
||||
if self.search_type == "similarity":
|
||||
docs = await self.vectorstore.asimilarity_search(
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
import itertools
|
||||
import random
|
||||
import uuid
|
||||
from typing import Dict, List, Optional, Set
|
||||
from typing import Dict, List, Optional, Set, cast
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from langchain_core.documents import DocumentSearchHit
|
||||
|
||||
from langchain_community.vectorstores import DatabricksVectorSearch
|
||||
from tests.integration_tests.vectorstores.fake_embeddings import (
|
||||
@@ -598,6 +599,13 @@ def test_similarity_score_threshold(index_details: dict, threshold: float) -> No
|
||||
assert len(search_result) == len(fake_texts)
|
||||
else:
|
||||
assert len(search_result) == 0
|
||||
result_with_scores = cast(
|
||||
List[DocumentSearchHit], retriever.invoke(query, include_score=True)
|
||||
)
|
||||
for idx, result in enumerate(result_with_scores):
|
||||
assert result.score >= threshold
|
||||
assert result.page_content == search_result[idx].page_content
|
||||
assert result.metadata == search_result[idx].metadata
|
||||
|
||||
|
||||
@pytest.mark.requires("databricks", "databricks.vector_search")
|
||||
|
||||
@@ -2,8 +2,13 @@
|
||||
and their transformations.
|
||||
|
||||
"""
|
||||
from langchain_core.documents.base import Document
|
||||
from langchain_core.documents.base import Document, DocumentSearchHit
|
||||
from langchain_core.documents.compressor import BaseDocumentCompressor
|
||||
from langchain_core.documents.transformers import BaseDocumentTransformer
|
||||
|
||||
__all__ = ["Document", "BaseDocumentTransformer", "BaseDocumentCompressor"]
|
||||
__all__ = [
|
||||
"Document",
|
||||
"DocumentSearchHit",
|
||||
"BaseDocumentTransformer",
|
||||
"BaseDocumentCompressor",
|
||||
]
|
||||
|
||||
@@ -30,3 +30,21 @@ class Document(Serializable):
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "document"]
|
||||
|
||||
|
||||
class DocumentSearchHit(Document):
|
||||
"""Class for storing a document and fields associated with retrieval."""
|
||||
|
||||
score: float
|
||||
"""Score associated with the document's relevance to a query."""
|
||||
type: Literal["DocumentSearchHit"] = "DocumentSearchHit" # type: ignore[assignment] # noqa: E501
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
"""Return whether this class is serializable."""
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "document_search_hit"]
|
||||
|
||||
@@ -157,6 +157,12 @@ SERIALIZABLE_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = {
|
||||
"base",
|
||||
"Document",
|
||||
),
|
||||
("langchain", "schema", "document_search_hit", "DocumentSearchHit"): (
|
||||
"langchain_core",
|
||||
"documents",
|
||||
"base",
|
||||
"DocumentSearchHit",
|
||||
),
|
||||
("langchain", "output_parsers", "fix", "OutputFixingParser"): (
|
||||
"langchain",
|
||||
"output_parsers",
|
||||
@@ -666,6 +672,12 @@ OLD_CORE_NAMESPACES_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = {
|
||||
"base",
|
||||
"Document",
|
||||
),
|
||||
("langchain_core", "documents", "base", "DocumentSearchHit"): (
|
||||
"langchain_core",
|
||||
"documents",
|
||||
"base",
|
||||
"DocumentSearchHit",
|
||||
),
|
||||
("langchain_core", "prompts", "chat", "AIMessagePromptTemplate"): (
|
||||
"langchain_core",
|
||||
"prompts",
|
||||
|
||||
@@ -39,6 +39,7 @@ from typing import (
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
from langchain_core.documents import DocumentSearchHit
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.pydantic_v1 import Field, root_validator
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
@@ -690,8 +691,17 @@ class VectorStoreRetriever(BaseRetriever):
|
||||
return values
|
||||
|
||||
def _get_relevant_documents(
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
include_score: bool = False,
|
||||
) -> List[Document]:
|
||||
if include_score and self.search_type != "similarity_score_threshold":
|
||||
raise ValueError(
|
||||
"include_score is only supported "
|
||||
"for search_type=similarity_score_threshold"
|
||||
)
|
||||
if self.search_type == "similarity":
|
||||
docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
|
||||
elif self.search_type == "similarity_score_threshold":
|
||||
@@ -700,6 +710,15 @@ class VectorStoreRetriever(BaseRetriever):
|
||||
query, **self.search_kwargs
|
||||
)
|
||||
)
|
||||
if include_score:
|
||||
return [
|
||||
DocumentSearchHit(
|
||||
page_content=doc.page_content,
|
||||
metadata=doc.metadata,
|
||||
score=score,
|
||||
)
|
||||
for doc, score in docs_and_similarities
|
||||
]
|
||||
docs = [doc for doc, _ in docs_and_similarities]
|
||||
elif self.search_type == "mmr":
|
||||
docs = self.vectorstore.max_marginal_relevance_search(
|
||||
@@ -710,8 +729,17 @@ class VectorStoreRetriever(BaseRetriever):
|
||||
return docs
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
include_score: bool = False,
|
||||
) -> List[Document]:
|
||||
if include_score and self.search_type != "similarity_score_threshold":
|
||||
raise ValueError(
|
||||
"include_score is only supported "
|
||||
"for search_type=similarity_score_threshold"
|
||||
)
|
||||
if self.search_type == "similarity":
|
||||
docs = await self.vectorstore.asimilarity_search(
|
||||
query, **self.search_kwargs
|
||||
@@ -722,6 +750,15 @@ class VectorStoreRetriever(BaseRetriever):
|
||||
query, **self.search_kwargs
|
||||
)
|
||||
)
|
||||
if include_score:
|
||||
return [
|
||||
DocumentSearchHit(
|
||||
page_content=doc.page_content,
|
||||
metadata=doc.metadata,
|
||||
score=score,
|
||||
)
|
||||
for doc, score in docs_and_similarities
|
||||
]
|
||||
docs = [doc for doc, _ in docs_and_similarities]
|
||||
elif self.search_type == "mmr":
|
||||
docs = await self.vectorstore.amax_marginal_relevance_search(
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
from langchain_core.documents import __all__
|
||||
|
||||
EXPECTED_ALL = ["Document", "BaseDocumentTransformer", "BaseDocumentCompressor"]
|
||||
EXPECTED_ALL = [
|
||||
"Document",
|
||||
"DocumentSearchHit",
|
||||
"BaseDocumentTransformer",
|
||||
"BaseDocumentCompressor",
|
||||
]
|
||||
|
||||
|
||||
def test_all_imports() -> None:
|
||||
|
||||
@@ -33,7 +33,7 @@ from langchain_core.callbacks.manager import (
|
||||
AsyncCallbackManagerForRetrieverRun,
|
||||
CallbackManagerForRetrieverRun,
|
||||
)
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.documents import Document, DocumentSearchHit
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.pydantic_v1 import Field, root_validator
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
@@ -192,19 +192,47 @@ class SelfQueryRetriever(BaseRetriever):
|
||||
return new_query, search_kwargs
|
||||
|
||||
def _get_docs_with_query(
|
||||
self, query: str, search_kwargs: Dict[str, Any]
|
||||
self, query: str, search_kwargs: Dict[str, Any], include_score: bool = False
|
||||
) -> List[Document]:
|
||||
docs = self.vectorstore.search(query, self.search_type, **search_kwargs)
|
||||
if include_score:
|
||||
docs_and_scores = self.vectorstore.similarity_search_with_score(
|
||||
query, **search_kwargs
|
||||
)
|
||||
return [
|
||||
DocumentSearchHit(
|
||||
page_content=doc.page_content, metadata=doc.metadata, score=score
|
||||
)
|
||||
for doc, score in docs_and_scores
|
||||
]
|
||||
else:
|
||||
docs = self.vectorstore.search(query, self.search_type, **search_kwargs)
|
||||
return docs
|
||||
|
||||
async def _aget_docs_with_query(
|
||||
self, query: str, search_kwargs: Dict[str, Any]
|
||||
self, query: str, search_kwargs: Dict[str, Any], include_score: bool = False
|
||||
) -> List[Document]:
|
||||
docs = await self.vectorstore.asearch(query, self.search_type, **search_kwargs)
|
||||
if include_score:
|
||||
docs_and_scores = await self.vectorstore.asimilarity_search_with_score(
|
||||
query, **search_kwargs
|
||||
)
|
||||
return [
|
||||
DocumentSearchHit(
|
||||
page_content=doc.page_content, metadata=doc.metadata, score=score
|
||||
)
|
||||
for doc, score in docs_and_scores
|
||||
]
|
||||
else:
|
||||
docs = await self.vectorstore.asearch(
|
||||
query, self.search_type, **search_kwargs
|
||||
)
|
||||
return docs
|
||||
|
||||
def _get_relevant_documents(
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: CallbackManagerForRetrieverRun,
|
||||
include_score: bool = False,
|
||||
) -> List[Document]:
|
||||
"""Get documents relevant for a query.
|
||||
|
||||
@@ -220,11 +248,17 @@ class SelfQueryRetriever(BaseRetriever):
|
||||
if self.verbose:
|
||||
logger.info(f"Generated Query: {structured_query}")
|
||||
new_query, search_kwargs = self._prepare_query(query, structured_query)
|
||||
docs = self._get_docs_with_query(new_query, search_kwargs)
|
||||
docs = self._get_docs_with_query(
|
||||
new_query, search_kwargs, include_score=include_score
|
||||
)
|
||||
return docs
|
||||
|
||||
async def _aget_relevant_documents(
|
||||
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
run_manager: AsyncCallbackManagerForRetrieverRun,
|
||||
include_score: bool = False,
|
||||
) -> List[Document]:
|
||||
"""Get documents relevant for a query.
|
||||
|
||||
@@ -240,7 +274,9 @@ class SelfQueryRetriever(BaseRetriever):
|
||||
if self.verbose:
|
||||
logger.info(f"Generated Query: {structured_query}")
|
||||
new_query, search_kwargs = self._prepare_query(query, structured_query)
|
||||
docs = await self._aget_docs_with_query(new_query, search_kwargs)
|
||||
docs = await self._aget_docs_with_query(
|
||||
new_query, search_kwargs, include_score=include_score
|
||||
)
|
||||
return docs
|
||||
|
||||
@classmethod
|
||||
|
||||
12
libs/langchain/poetry.lock
generated
12
libs/langchain/poetry.lock
generated
@@ -1,4 +1,4 @@
|
||||
# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "aiodns"
|
||||
@@ -3497,7 +3497,7 @@ url = "../community"
|
||||
|
||||
[[package]]
|
||||
name = "langchain-core"
|
||||
version = "0.1.51"
|
||||
version = "0.1.52"
|
||||
description = "Building applications with LLMs through composability"
|
||||
optional = false
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
@@ -4727,6 +4727,7 @@ description = "Nvidia JIT LTO Library"
|
||||
optional = true
|
||||
python-versions = ">=3"
|
||||
files = [
|
||||
{file = "nvidia_nvjitlink_cu12-12.4.99-py3-none-manylinux2014_aarch64.whl", hash = "sha256:75d6498c96d9adb9435f2bbdbddb479805ddfb97b5c1b32395c694185c20ca57"},
|
||||
{file = "nvidia_nvjitlink_cu12-12.4.99-py3-none-manylinux2014_x86_64.whl", hash = "sha256:c6428836d20fe7e327191c175791d38570e10762edc588fb46749217cd444c74"},
|
||||
{file = "nvidia_nvjitlink_cu12-12.4.99-py3-none-win_amd64.whl", hash = "sha256:991905ffa2144cb603d8ca7962d75c35334ae82bf92820b6ba78157277da1ad2"},
|
||||
]
|
||||
@@ -6074,26 +6075,31 @@ python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "PyMuPDF-1.23.26-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:645a05321aecc8c45739f71f0eb574ce33138d19189582ffa5241fea3a8e2549"},
|
||||
{file = "PyMuPDF-1.23.26-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:2dfc9e010669ae92fade6fb72aaea49ebe3b8dcd7ee4dcbbe50115abcaa4d3fe"},
|
||||
{file = "PyMuPDF-1.23.26-cp310-none-manylinux2014_aarch64.whl", hash = "sha256:734ee380b3abd038602be79114194a3cb74ac102b7c943bcb333104575922c50"},
|
||||
{file = "PyMuPDF-1.23.26-cp310-none-manylinux2014_x86_64.whl", hash = "sha256:b22f8d854f8196ad5b20308c1cebad3d5189ed9f0988acbafa043947ea7e6c55"},
|
||||
{file = "PyMuPDF-1.23.26-cp310-none-win32.whl", hash = "sha256:cc0f794e3466bc96b5bf79d42fbc1551428751e3fef38ebc10ac70396b676144"},
|
||||
{file = "PyMuPDF-1.23.26-cp310-none-win_amd64.whl", hash = "sha256:2eb701247d8e685a24e45899d1175f01a3ce5fc792a4431c91fbb68633b29298"},
|
||||
{file = "PyMuPDF-1.23.26-cp311-none-macosx_10_9_x86_64.whl", hash = "sha256:e2804a64bb57da414781e312fb0561f6be67658ad57ed4a73dce008b23fc70a6"},
|
||||
{file = "PyMuPDF-1.23.26-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:97b40bb22e3056874634617a90e0ed24a5172cf71791b9e25d1d91c6743bc567"},
|
||||
{file = "PyMuPDF-1.23.26-cp311-none-manylinux2014_aarch64.whl", hash = "sha256:fab8833559bc47ab26ce736f915b8fc1dd37c108049b90396f7cd5e1004d7593"},
|
||||
{file = "PyMuPDF-1.23.26-cp311-none-manylinux2014_x86_64.whl", hash = "sha256:f25aafd3e7fb9d7761a22acf2b67d704f04cc36d4dc33a3773f0eb3f4ec3606f"},
|
||||
{file = "PyMuPDF-1.23.26-cp311-none-win32.whl", hash = "sha256:05e672ed3e82caca7ef02a88ace30130b1dd392a1190f03b2b58ffe7aa331400"},
|
||||
{file = "PyMuPDF-1.23.26-cp311-none-win_amd64.whl", hash = "sha256:92b3c4dd4d0491d495f333be2d41f4e1c155a409bc9d04b5ff29655dccbf4655"},
|
||||
{file = "PyMuPDF-1.23.26-cp312-none-macosx_10_9_x86_64.whl", hash = "sha256:a217689ede18cc6991b4e6a78afee8a440b3075d53b9dec4ba5ef7487d4547e9"},
|
||||
{file = "PyMuPDF-1.23.26-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:42ad2b819b90ce1947e11b90ec5085889df0a2e3aa0207bc97ecacfc6157cabc"},
|
||||
{file = "PyMuPDF-1.23.26-cp312-none-manylinux2014_aarch64.whl", hash = "sha256:99607649f89a02bba7d8ebe96e2410664316adc95e9337f7dfeff6a154f93049"},
|
||||
{file = "PyMuPDF-1.23.26-cp312-none-manylinux2014_x86_64.whl", hash = "sha256:bb42d4b8407b4de7cb58c28f01449f16f32a6daed88afb41108f1aeb3552bdd4"},
|
||||
{file = "PyMuPDF-1.23.26-cp312-none-win32.whl", hash = "sha256:c40d044411615e6f0baa7d3d933b3032cf97e168c7fa77d1be8a46008c109aee"},
|
||||
{file = "PyMuPDF-1.23.26-cp312-none-win_amd64.whl", hash = "sha256:3f876533aa7f9a94bcd9a0225ce72571b7808260903fec1d95c120bc842fb52d"},
|
||||
{file = "PyMuPDF-1.23.26-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:52df831d46beb9ff494f5fba3e5d069af6d81f49abf6b6e799ee01f4f8fa6799"},
|
||||
{file = "PyMuPDF-1.23.26-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:0bbb0cf6593e53524f3fc26fb5e6ead17c02c64791caec7c4afe61b677dedf80"},
|
||||
{file = "PyMuPDF-1.23.26-cp38-none-manylinux2014_aarch64.whl", hash = "sha256:5ef4360f20015673c20cf59b7e19afc97168795188c584254ed3778cde43ce77"},
|
||||
{file = "PyMuPDF-1.23.26-cp38-none-manylinux2014_x86_64.whl", hash = "sha256:d7cd88842b2e7f4c71eef4d87c98c35646b80b60e6375392d7ce40e519261f59"},
|
||||
{file = "PyMuPDF-1.23.26-cp38-none-win32.whl", hash = "sha256:6577e2f473625e2d0df5f5a3bf1e4519e94ae749733cc9937994d1b256687bfa"},
|
||||
{file = "PyMuPDF-1.23.26-cp38-none-win_amd64.whl", hash = "sha256:fbe1a3255b2cd0d769b2da2c4efdd0c0f30d4961a1aac02c0f75cf951b337aa4"},
|
||||
{file = "PyMuPDF-1.23.26-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:73fce034f2afea886a59ead2d0caedf27e2b2a8558b5da16d0286882e0b1eb82"},
|
||||
{file = "PyMuPDF-1.23.26-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:b3de8618b7cb5b36db611083840b3bcf09b11a893e2d8262f4e042102c7e65de"},
|
||||
{file = "PyMuPDF-1.23.26-cp39-none-manylinux2014_aarch64.whl", hash = "sha256:879e7f5ad35709d8760ab6103c3d5dac8ab8043a856ab3653fd324af7358ee87"},
|
||||
{file = "PyMuPDF-1.23.26-cp39-none-manylinux2014_x86_64.whl", hash = "sha256:deee96c2fd415ded7b5070d8d5b2c60679aee6ed0e28ac0d2cb998060d835c2c"},
|
||||
{file = "PyMuPDF-1.23.26-cp39-none-win32.whl", hash = "sha256:9f7f4ef99dd8ac97fb0b852efa3dcbee515798078b6c79a6a13c7b1e7c5d41a4"},
|
||||
{file = "PyMuPDF-1.23.26-cp39-none-win_amd64.whl", hash = "sha256:ba9a54552c7afb9ec85432c765e2fa9a81413acfaa7d70db7c9b528297749e5b"},
|
||||
@@ -9404,4 +9410,4 @@ text-helpers = ["chardet"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
content-hash = "9ed4d0b11749d1f98e8fbe2895a94e4bc90975817873e52a70f2bbcee934ce19"
|
||||
content-hash = "3e88db5c104ca41c6c320e0d9d0c985430e2c9cf92586c40c175bab88154f91c"
|
||||
|
||||
@@ -12,7 +12,7 @@ langchain-server = "langchain.server:main"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.8.1,<4.0"
|
||||
langchain-core = "^0.1.48"
|
||||
langchain-core = "^0.1.52"
|
||||
langchain-text-splitters = ">=0.0.1,<0.1"
|
||||
langchain-community = ">=0.0.37,<0.1"
|
||||
langsmith = "^0.1.17"
|
||||
|
||||
Reference in New Issue
Block a user