Compare commits


17 Commits

Author | SHA1 | Message | Date
Chester Curme | 166ab8217f | merge | 2024-05-07 10:54:47 -04:00
Chester Curme | 46caee5680 | Merge branch 'master' into cc/retriever_score | 2024-05-07 10:46:08 -04:00
Chester Curme | 5fb7d3b4ba | propagate metadata | 2024-04-26 16:43:35 -04:00
Chester Curme | 402298e376 | Merge branch 'master' into cc/retriever_score | 2024-04-26 15:18:04 -04:00
Chester Curme | 267ee9db4c | Merge branch 'master' into cc/retriever_score | 2024-04-26 15:00:08 -04:00
Chester Curme | bc7af5fd7e | bump langchain to 0.1.17rc1 | 2024-04-26 14:07:11 -04:00
Chester Curme | abf1f4c124 | bump core to 0.1.47rc1 | 2024-04-26 14:05:42 -04:00
Chester Curme | c9fc0447ec | Merge branch 'master' into cc/retriever_score | 2024-04-26 13:51:50 -04:00
Chester Curme | c262cef1fb | update SelfQueryRetriever | 2024-04-25 17:32:32 -04:00
Chester Curme | 26455d156d | Merge branch 'master' into cc/retriever_score | 2024-04-25 16:55:38 -04:00
Chester Curme | ceea324071 | cr | 2024-04-23 16:01:28 -04:00
Chester Curme | 1544c9d050 | update | 2024-04-23 15:06:20 -04:00
Chester Curme | 7a78068cd4 | fix test | 2024-04-23 15:02:02 -04:00
Chester Curme | d91fd8cdcb | update | 2024-04-23 14:56:17 -04:00
Chester Curme | 1f15b0885d | add test | 2024-04-23 14:34:53 -04:00
Chester Curme | 4f16714195 | update VectorStoreRetriever | 2024-04-23 14:34:46 -04:00
Chester Curme | f8598a7e48 | add DocumentSearchHit | 2024-04-23 14:34:07 -04:00
10 changed files with 156 additions and 21 deletions

View File

@@ -1444,7 +1444,11 @@ class RedisVectorStoreRetriever(VectorStoreRetriever):
         arbitrary_types_allowed = True
     def _get_relevant_documents(
-        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
+        self,
+        query: str,
+        *,
+        run_manager: CallbackManagerForRetrieverRun,
+        include_score: bool = False,
     ) -> List[Document]:
         if self.search_type == "similarity":
             docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
@@ -1472,7 +1476,11 @@ class RedisVectorStoreRetriever(VectorStoreRetriever):
         return docs
     async def _aget_relevant_documents(
-        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
+        self,
+        query: str,
+        *,
+        run_manager: AsyncCallbackManagerForRetrieverRun,
+        include_score: bool = False,
     ) -> List[Document]:
         if self.search_type == "similarity":
             docs = await self.vectorstore.asimilarity_search(
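Note: the Redis override only widens its signature so it stays in sync with the new base-class keyword; the flag itself travels from `BaseRetriever.invoke` into `_get_relevant_documents` as a plain keyword argument. A minimal sketch of that call path (hedged: `redis_store` is assumed to be an already-constructed langchain_community Redis vector store):

# Hedged sketch: `redis_store` is assumed to exist; only the keyword plumbing is shown.
retriever = redis_store.as_retriever(search_type="similarity")
# BaseRetriever.invoke forwards extra keyword arguments to _get_relevant_documents,
# which is how include_score reaches the widened signatures above.
docs = retriever.invoke("what is a vector store?", include_score=False)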

View File

@@ -1,10 +1,11 @@
 import itertools
 import random
 import uuid
-from typing import Dict, List, Optional, Set
+from typing import Dict, List, Optional, Set, cast
 from unittest.mock import MagicMock, patch
 import pytest
+from langchain_core.documents import DocumentSearchHit
 from langchain_community.vectorstores import DatabricksVectorSearch
 from tests.integration_tests.vectorstores.fake_embeddings import (
@@ -598,6 +599,13 @@ def test_similarity_score_threshold(index_details: dict, threshold: float) -> No
         assert len(search_result) == len(fake_texts)
     else:
         assert len(search_result) == 0
+    result_with_scores = cast(
+        List[DocumentSearchHit], retriever.invoke(query, include_score=True)
+    )
+    for idx, result in enumerate(result_with_scores):
+        assert result.score >= threshold
+        assert result.page_content == search_result[idx].page_content
+        assert result.metadata == search_result[idx].metadata
 @pytest.mark.requires("databricks", "databricks.vector_search")

View File

@@ -2,8 +2,13 @@
 and their transformations.
 """
-from langchain_core.documents.base import Document
+from langchain_core.documents.base import Document, DocumentSearchHit
 from langchain_core.documents.compressor import BaseDocumentCompressor
 from langchain_core.documents.transformers import BaseDocumentTransformer
-__all__ = ["Document", "BaseDocumentTransformer", "BaseDocumentCompressor"]
+__all__ = [
+    "Document",
+    "DocumentSearchHit",
+    "BaseDocumentTransformer",
+    "BaseDocumentCompressor",
+]

View File

@@ -30,3 +30,21 @@ class Document(Serializable):
     def get_lc_namespace(cls) -> List[str]:
         """Get the namespace of the langchain object."""
         return ["langchain", "schema", "document"]
+class DocumentSearchHit(Document):
+    """Class for storing a document and fields associated with retrieval."""
+    score: float
+    """Score associated with the document's relevance to a query."""
+    type: Literal["DocumentSearchHit"] = "DocumentSearchHit" # type: ignore[assignment] # noqa: E501
+    @classmethod
+    def is_lc_serializable(cls) -> bool:
+        """Return whether this class is serializable."""
+        return True
+    @classmethod
+    def get_lc_namespace(cls) -> List[str]:
+        """Get the namespace of the langchain object."""
+        return ["langchain", "schema", "document_search_hit"]

View File

@@ -157,6 +157,12 @@ SERIALIZABLE_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = {
"base",
"Document",
),
("langchain", "schema", "document_search_hit", "DocumentSearchHit"): (
"langchain_core",
"documents",
"base",
"DocumentSearchHit",
),
("langchain", "output_parsers", "fix", "OutputFixingParser"): (
"langchain",
"output_parsers",
@@ -666,6 +672,12 @@ OLD_CORE_NAMESPACES_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = {
"base",
"Document",
),
("langchain_core", "documents", "base", "DocumentSearchHit"): (
"langchain_core",
"documents",
"base",
"DocumentSearchHit",
),
("langchain_core", "prompts", "chat", "AIMessagePromptTemplate"): (
"langchain_core",
"prompts",

View File

@@ -39,6 +39,7 @@ from typing import (
     TypeVar,
 )
+from langchain_core.documents import DocumentSearchHit
 from langchain_core.embeddings import Embeddings
 from langchain_core.pydantic_v1 import Field, root_validator
 from langchain_core.retrievers import BaseRetriever
@@ -690,8 +691,17 @@ class VectorStoreRetriever(BaseRetriever):
         return values
     def _get_relevant_documents(
-        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
+        self,
+        query: str,
+        *,
+        run_manager: CallbackManagerForRetrieverRun,
+        include_score: bool = False,
     ) -> List[Document]:
+        if include_score and self.search_type != "similarity_score_threshold":
+            raise ValueError(
+                "include_score is only supported "
+                "for search_type=similarity_score_threshold"
+            )
         if self.search_type == "similarity":
             docs = self.vectorstore.similarity_search(query, **self.search_kwargs)
         elif self.search_type == "similarity_score_threshold":
@@ -700,6 +710,15 @@ class VectorStoreRetriever(BaseRetriever):
                     query, **self.search_kwargs
                 )
             )
+            if include_score:
+                return [
+                    DocumentSearchHit(
+                        page_content=doc.page_content,
+                        metadata=doc.metadata,
+                        score=score,
+                    )
+                    for doc, score in docs_and_similarities
+                ]
             docs = [doc for doc, _ in docs_and_similarities]
         elif self.search_type == "mmr":
             docs = self.vectorstore.max_marginal_relevance_search(
@@ -710,8 +729,17 @@ class VectorStoreRetriever(BaseRetriever):
         return docs
     async def _aget_relevant_documents(
-        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
+        self,
+        query: str,
+        *,
+        run_manager: AsyncCallbackManagerForRetrieverRun,
+        include_score: bool = False,
     ) -> List[Document]:
+        if include_score and self.search_type != "similarity_score_threshold":
+            raise ValueError(
+                "include_score is only supported "
+                "for search_type=similarity_score_threshold"
+            )
         if self.search_type == "similarity":
             docs = await self.vectorstore.asimilarity_search(
                 query, **self.search_kwargs
@@ -722,6 +750,15 @@ class VectorStoreRetriever(BaseRetriever):
                     query, **self.search_kwargs
                 )
             )
+            if include_score:
+                return [
+                    DocumentSearchHit(
+                        page_content=doc.page_content,
+                        metadata=doc.metadata,
+                        score=score,
+                    )
+                    for doc, score in docs_and_similarities
+                ]
             docs = [doc for doc, _ in docs_and_similarities]
         elif self.search_type == "mmr":
             docs = await self.vectorstore.amax_marginal_relevance_search(
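Taken together with the guard above, the retriever-level behavior is: `include_score=True` is only honored for `search_type="similarity_score_threshold"`, where the relevance scores returned by `similarity_search_with_relevance_scores` are preserved on `DocumentSearchHit` objects; any other search type raises the `ValueError`. A usage sketch (hedged: `vectorstore` is assumed to be any store that implements `similarity_search_with_relevance_scores`):

from langchain_core.documents import DocumentSearchHit

retriever = vectorstore.as_retriever(
    search_type="similarity_score_threshold",
    search_kwargs={"score_threshold": 0.5},
)
hits = retriever.invoke("how do I cache LLM calls?", include_score=True)
for hit in hits:
    assert isinstance(hit, DocumentSearchHit)
    print(f"{hit.score:.3f}  {hit.page_content[:60]}")

# With include_score=True and any other search_type (e.g. "mmr"), invoke raises ValueError.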

View File

@@ -1,6 +1,11 @@
 from langchain_core.documents import __all__
-EXPECTED_ALL = ["Document", "BaseDocumentTransformer", "BaseDocumentCompressor"]
+EXPECTED_ALL = [
+    "Document",
+    "DocumentSearchHit",
+    "BaseDocumentTransformer",
+    "BaseDocumentCompressor",
+]
 def test_all_imports() -> None:

View File

@@ -33,7 +33,7 @@ from langchain_core.callbacks.manager import (
     AsyncCallbackManagerForRetrieverRun,
     CallbackManagerForRetrieverRun,
 )
-from langchain_core.documents import Document
+from langchain_core.documents import Document, DocumentSearchHit
 from langchain_core.language_models import BaseLanguageModel
 from langchain_core.pydantic_v1 import Field, root_validator
 from langchain_core.retrievers import BaseRetriever
@@ -192,19 +192,47 @@ class SelfQueryRetriever(BaseRetriever):
         return new_query, search_kwargs
     def _get_docs_with_query(
-        self, query: str, search_kwargs: Dict[str, Any]
+        self, query: str, search_kwargs: Dict[str, Any], include_score: bool = False
     ) -> List[Document]:
-        docs = self.vectorstore.search(query, self.search_type, **search_kwargs)
+        if include_score:
+            docs_and_scores = self.vectorstore.similarity_search_with_score(
+                query, **search_kwargs
+            )
+            return [
+                DocumentSearchHit(
+                    page_content=doc.page_content, metadata=doc.metadata, score=score
+                )
+                for doc, score in docs_and_scores
+            ]
+        else:
+            docs = self.vectorstore.search(query, self.search_type, **search_kwargs)
         return docs
     async def _aget_docs_with_query(
-        self, query: str, search_kwargs: Dict[str, Any]
+        self, query: str, search_kwargs: Dict[str, Any], include_score: bool = False
     ) -> List[Document]:
-        docs = await self.vectorstore.asearch(query, self.search_type, **search_kwargs)
+        if include_score:
+            docs_and_scores = await self.vectorstore.asimilarity_search_with_score(
+                query, **search_kwargs
+            )
+            return [
+                DocumentSearchHit(
+                    page_content=doc.page_content, metadata=doc.metadata, score=score
+                )
+                for doc, score in docs_and_scores
+            ]
+        else:
+            docs = await self.vectorstore.asearch(
+                query, self.search_type, **search_kwargs
+            )
         return docs
     def _get_relevant_documents(
-        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
+        self,
+        query: str,
+        *,
+        run_manager: CallbackManagerForRetrieverRun,
+        include_score: bool = False,
     ) -> List[Document]:
         """Get documents relevant for a query.
@@ -220,11 +248,17 @@ class SelfQueryRetriever(BaseRetriever):
         if self.verbose:
             logger.info(f"Generated Query: {structured_query}")
         new_query, search_kwargs = self._prepare_query(query, structured_query)
-        docs = self._get_docs_with_query(new_query, search_kwargs)
+        docs = self._get_docs_with_query(
+            new_query, search_kwargs, include_score=include_score
+        )
         return docs
     async def _aget_relevant_documents(
-        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
+        self,
+        query: str,
+        *,
+        run_manager: AsyncCallbackManagerForRetrieverRun,
+        include_score: bool = False,
     ) -> List[Document]:
         """Get documents relevant for a query.
@@ -240,7 +274,9 @@ class SelfQueryRetriever(BaseRetriever):
         if self.verbose:
             logger.info(f"Generated Query: {structured_query}")
         new_query, search_kwargs = self._prepare_query(query, structured_query)
-        docs = await self._aget_docs_with_query(new_query, search_kwargs)
+        docs = await self._aget_docs_with_query(
+            new_query, search_kwargs, include_score=include_score
+        )
         return docs
     @classmethod
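The self-query path mirrors the retriever change, with one difference worth noting: it fetches scores via `similarity_search_with_score` (raw vector-store scores) rather than the normalized relevance scores used by `VectorStoreRetriever`. A usage sketch (hedged: `llm`, `vectorstore`, and `metadata_field_info` are assumed to exist):

from langchain.retrievers.self_query.base import SelfQueryRetriever

retriever = SelfQueryRetriever.from_llm(
    llm,
    vectorstore,
    document_contents="Brief summary of a movie",
    metadata_field_info=metadata_field_info,
)
hits = retriever.invoke("highly rated science fiction", include_score=True)
for hit in hits:
    print(hit.score, hit.metadata)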

View File

@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
 [[package]]
 name = "aiodns"
@@ -3497,7 +3497,7 @@ url = "../community"
 [[package]]
 name = "langchain-core"
-version = "0.1.51"
+version = "0.1.52"
 description = "Building applications with LLMs through composability"
 optional = false
 python-versions = ">=3.8.1,<4.0"
@@ -4727,6 +4727,7 @@ description = "Nvidia JIT LTO Library"
 optional = true
 python-versions = ">=3"
 files = [
+    {file = "nvidia_nvjitlink_cu12-12.4.99-py3-none-manylinux2014_aarch64.whl", hash = "sha256:75d6498c96d9adb9435f2bbdbddb479805ddfb97b5c1b32395c694185c20ca57"},
     {file = "nvidia_nvjitlink_cu12-12.4.99-py3-none-manylinux2014_x86_64.whl", hash = "sha256:c6428836d20fe7e327191c175791d38570e10762edc588fb46749217cd444c74"},
     {file = "nvidia_nvjitlink_cu12-12.4.99-py3-none-win_amd64.whl", hash = "sha256:991905ffa2144cb603d8ca7962d75c35334ae82bf92820b6ba78157277da1ad2"},
 ]
@@ -6074,26 +6075,31 @@ python-versions = ">=3.8"
files = [
{file = "PyMuPDF-1.23.26-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:645a05321aecc8c45739f71f0eb574ce33138d19189582ffa5241fea3a8e2549"},
{file = "PyMuPDF-1.23.26-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:2dfc9e010669ae92fade6fb72aaea49ebe3b8dcd7ee4dcbbe50115abcaa4d3fe"},
{file = "PyMuPDF-1.23.26-cp310-none-manylinux2014_aarch64.whl", hash = "sha256:734ee380b3abd038602be79114194a3cb74ac102b7c943bcb333104575922c50"},
{file = "PyMuPDF-1.23.26-cp310-none-manylinux2014_x86_64.whl", hash = "sha256:b22f8d854f8196ad5b20308c1cebad3d5189ed9f0988acbafa043947ea7e6c55"},
{file = "PyMuPDF-1.23.26-cp310-none-win32.whl", hash = "sha256:cc0f794e3466bc96b5bf79d42fbc1551428751e3fef38ebc10ac70396b676144"},
{file = "PyMuPDF-1.23.26-cp310-none-win_amd64.whl", hash = "sha256:2eb701247d8e685a24e45899d1175f01a3ce5fc792a4431c91fbb68633b29298"},
{file = "PyMuPDF-1.23.26-cp311-none-macosx_10_9_x86_64.whl", hash = "sha256:e2804a64bb57da414781e312fb0561f6be67658ad57ed4a73dce008b23fc70a6"},
{file = "PyMuPDF-1.23.26-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:97b40bb22e3056874634617a90e0ed24a5172cf71791b9e25d1d91c6743bc567"},
{file = "PyMuPDF-1.23.26-cp311-none-manylinux2014_aarch64.whl", hash = "sha256:fab8833559bc47ab26ce736f915b8fc1dd37c108049b90396f7cd5e1004d7593"},
{file = "PyMuPDF-1.23.26-cp311-none-manylinux2014_x86_64.whl", hash = "sha256:f25aafd3e7fb9d7761a22acf2b67d704f04cc36d4dc33a3773f0eb3f4ec3606f"},
{file = "PyMuPDF-1.23.26-cp311-none-win32.whl", hash = "sha256:05e672ed3e82caca7ef02a88ace30130b1dd392a1190f03b2b58ffe7aa331400"},
{file = "PyMuPDF-1.23.26-cp311-none-win_amd64.whl", hash = "sha256:92b3c4dd4d0491d495f333be2d41f4e1c155a409bc9d04b5ff29655dccbf4655"},
{file = "PyMuPDF-1.23.26-cp312-none-macosx_10_9_x86_64.whl", hash = "sha256:a217689ede18cc6991b4e6a78afee8a440b3075d53b9dec4ba5ef7487d4547e9"},
{file = "PyMuPDF-1.23.26-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:42ad2b819b90ce1947e11b90ec5085889df0a2e3aa0207bc97ecacfc6157cabc"},
{file = "PyMuPDF-1.23.26-cp312-none-manylinux2014_aarch64.whl", hash = "sha256:99607649f89a02bba7d8ebe96e2410664316adc95e9337f7dfeff6a154f93049"},
{file = "PyMuPDF-1.23.26-cp312-none-manylinux2014_x86_64.whl", hash = "sha256:bb42d4b8407b4de7cb58c28f01449f16f32a6daed88afb41108f1aeb3552bdd4"},
{file = "PyMuPDF-1.23.26-cp312-none-win32.whl", hash = "sha256:c40d044411615e6f0baa7d3d933b3032cf97e168c7fa77d1be8a46008c109aee"},
{file = "PyMuPDF-1.23.26-cp312-none-win_amd64.whl", hash = "sha256:3f876533aa7f9a94bcd9a0225ce72571b7808260903fec1d95c120bc842fb52d"},
{file = "PyMuPDF-1.23.26-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:52df831d46beb9ff494f5fba3e5d069af6d81f49abf6b6e799ee01f4f8fa6799"},
{file = "PyMuPDF-1.23.26-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:0bbb0cf6593e53524f3fc26fb5e6ead17c02c64791caec7c4afe61b677dedf80"},
{file = "PyMuPDF-1.23.26-cp38-none-manylinux2014_aarch64.whl", hash = "sha256:5ef4360f20015673c20cf59b7e19afc97168795188c584254ed3778cde43ce77"},
{file = "PyMuPDF-1.23.26-cp38-none-manylinux2014_x86_64.whl", hash = "sha256:d7cd88842b2e7f4c71eef4d87c98c35646b80b60e6375392d7ce40e519261f59"},
{file = "PyMuPDF-1.23.26-cp38-none-win32.whl", hash = "sha256:6577e2f473625e2d0df5f5a3bf1e4519e94ae749733cc9937994d1b256687bfa"},
{file = "PyMuPDF-1.23.26-cp38-none-win_amd64.whl", hash = "sha256:fbe1a3255b2cd0d769b2da2c4efdd0c0f30d4961a1aac02c0f75cf951b337aa4"},
{file = "PyMuPDF-1.23.26-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:73fce034f2afea886a59ead2d0caedf27e2b2a8558b5da16d0286882e0b1eb82"},
{file = "PyMuPDF-1.23.26-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:b3de8618b7cb5b36db611083840b3bcf09b11a893e2d8262f4e042102c7e65de"},
{file = "PyMuPDF-1.23.26-cp39-none-manylinux2014_aarch64.whl", hash = "sha256:879e7f5ad35709d8760ab6103c3d5dac8ab8043a856ab3653fd324af7358ee87"},
{file = "PyMuPDF-1.23.26-cp39-none-manylinux2014_x86_64.whl", hash = "sha256:deee96c2fd415ded7b5070d8d5b2c60679aee6ed0e28ac0d2cb998060d835c2c"},
{file = "PyMuPDF-1.23.26-cp39-none-win32.whl", hash = "sha256:9f7f4ef99dd8ac97fb0b852efa3dcbee515798078b6c79a6a13c7b1e7c5d41a4"},
{file = "PyMuPDF-1.23.26-cp39-none-win_amd64.whl", hash = "sha256:ba9a54552c7afb9ec85432c765e2fa9a81413acfaa7d70db7c9b528297749e5b"},
@@ -9404,4 +9410,4 @@ text-helpers = ["chardet"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0"
-content-hash = "9ed4d0b11749d1f98e8fbe2895a94e4bc90975817873e52a70f2bbcee934ce19"
+content-hash = "3e88db5c104ca41c6c320e0d9d0c985430e2c9cf92586c40c175bab88154f91c"

View File

@@ -12,7 +12,7 @@ langchain-server = "langchain.server:main"
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
-langchain-core = "^0.1.48"
+langchain-core = "^0.1.52"
 langchain-text-splitters = ">=0.0.1,<0.1"
 langchain-community = ">=0.0.37,<0.1"
 langsmith = "^0.1.17"