mongodb[patch]: Added scoring threshold to caching (#19286)
## Description

Semantic Cache can retrieve noisy information if the score threshold for the value is too low. Adding the ability to set a `score_threshold` at cache construction lets callers filter out low-scoring, noisy matches.

- [x] **Add tests and docs**
  1. Added tests that confirm the `score_threshold` query is valid.
- [x] **Lint and test**: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified. See contribution guidelines for more: https://python.langchain.com/docs/contributing/

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
Parent: 30e4a35d7a
Commit: f8078e41e5
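For context, a minimal usage sketch of the new parameter. This is illustrative only: the URI and names are placeholders, and the `embedding` argument and import paths assume the usual `langchain-mongodb` setup.

```python
# Usage sketch only: connection values are placeholders, and
# OpenAIEmbeddings stands in for any embedding model.
from langchain_core.globals import set_llm_cache
from langchain_mongodb.cache import MongoDBAtlasSemanticCache
from langchain_openai import OpenAIEmbeddings

set_llm_cache(
    MongoDBAtlasSemanticCache(
        connection_string="mongodb+srv://user:pass@cluster.example.mongodb.net",
        embedding=OpenAIEmbeddings(),
        database_name="default",
        collection_name="default",
        index_name="default",
        score_threshold=0.5,  # new: discard matches scoring below 0.5
        wait_until_ready=True,  # block until the Atlas index is queryable
    )
)
```

With `score_threshold=0.5`, a cached entry is returned only when the nearest semantic match scores at least 0.5; leaving it `None` preserves the previous unfiltered behavior.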
In `MongoDBAtlasSemanticCache` (cache implementation):

```diff
@@ -217,6 +217,7 @@ class MongoDBAtlasSemanticCache(BaseCache, MongoDBAtlasVectorSearch):
         database_name: str = "default",
         index_name: str = "default",
         wait_until_ready: bool = False,
+        score_threshold: Optional[float] = None,
         **kwargs: Dict[str, Any],
     ):
         """
@@ -237,6 +238,7 @@ class MongoDBAtlasSemanticCache(BaseCache, MongoDBAtlasVectorSearch):
         """
         client = _generate_mongo_client(connection_string)
         self.collection = client[database_name][collection_name]
+        self.score_threshold = score_threshold
         self._wait_until_ready = wait_until_ready
         super().__init__(
             collection=self.collection,
@@ -247,8 +249,17 @@ class MongoDBAtlasSemanticCache(BaseCache, MongoDBAtlasVectorSearch):

     def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
         """Look up based on prompt and llm_string."""
+        post_filter_pipeline = (
+            [{"$match": {"score": {"$gte": self.score_threshold}}}]
+            if self.score_threshold
+            else None
+        )
+
         search_response = self.similarity_search_with_score(
-            prompt, 1, pre_filter={self.LLM: {"$eq": llm_string}}
+            prompt,
+            1,
+            pre_filter={self.LLM: {"$eq": llm_string}},
+            post_filter_pipeline=post_filter_pipeline,
         )
         if search_response:
             return_val = search_response[0][0].metadata.get(self.RETURN_VAL)
```
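The threshold is applied as an ordinary aggregation stage: when `score_threshold` is set, `lookup` passes a `$match` stage through `post_filter_pipeline`, so any vector-search hit whose `score` falls below the bound is dropped before it reaches the cache result. A pure-Python sketch of the equivalent filtering (illustrative only; `apply_score_threshold` is not part of the package):

```python
from typing import Any, Dict, List, Optional, Tuple

Hit = Tuple[Dict[str, Any], float]  # (document, similarity score)

def apply_score_threshold(hits: List[Hit], score_threshold: Optional[float]) -> List[Hit]:
    """Mimic the {"$match": {"score": {"$gte": threshold}}} post-filter stage."""
    if not score_threshold:  # as in the diff, None (or 0.0) disables filtering
        return hits
    return [(doc, score) for doc, score in hits if score >= score_threshold]

# A noisy 0.31 match is discarded at threshold 0.5; the strong match survives.
hits = [({"return_val": "cached answer"}, 0.92), ({"return_val": "noise"}, 0.31)]
assert apply_score_threshold(hits, 0.5) == [hits[0]]
```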
In `pyproject.toml` of `langchain-mongodb`:

```diff
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "langchain-mongodb"
-version = "0.1.2"
+version = "0.1.3"
 description = "An integration package connecting MongoDB and LangChain"
 authors = []
 readme = "README.md"
```
In the integration tests:

```diff
@@ -30,6 +30,7 @@ def llm_cache(cls: Any) -> BaseCache:
             collection_name=COLLECTION,
             database_name=DATABASE,
             index_name=INDEX_NAME,
+            score_threshold=0.5,
             wait_until_ready=True,
         )
     )
```
```diff
@@ -92,13 +93,17 @@ def _execute_test(
     ],
 )
 @pytest.mark.parametrize("cacher", [MongoDBCache, MongoDBAtlasSemanticCache])
+@pytest.mark.parametrize("remove_score", [True, False])
 def test_mongodb_cache(
+    remove_score: bool,
     cacher: Union[MongoDBCache, MongoDBAtlasSemanticCache],
     prompt: Union[str, List[BaseMessage]],
     llm: Union[str, FakeLLM, FakeChatModel],
     response: List[Generation],
 ) -> None:
     llm_cache(cacher)
+    if remove_score:
+        get_llm_cache().score_threshold = None  # type: ignore
     try:
         _execute_test(prompt, llm, response)
     finally:
```
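The `remove_score` parametrization exercises both paths: since `llm_cache()` registers the cache globally, the test flips the threshold off on the live object rather than rebuilding it. Roughly (assuming the `get_llm_cache` accessor from `langchain_core.globals`):

```python
from langchain_core.globals import get_llm_cache

cache = get_llm_cache()  # the cache instance registered by llm_cache()
cache.score_threshold = None  # unfiltered: the nearest match always wins
```

The unit tests below apply the same pattern against a mocked collection.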
In the unit tests:

```diff
@@ -54,6 +54,7 @@ class PatchedMongoDBAtlasSemanticCache(MongoDBAtlasSemanticCache):
     ):
         self.collection = MockCollection()
         self._wait_until_ready = False
+        self.score_threshold = None
         MongoDBAtlasVectorSearch.__init__(
             self,
             self.collection,
```
```diff
@@ -144,13 +145,17 @@ def _execute_test(
 @pytest.mark.parametrize(
     "cacher", [PatchedMongoDBCache, PatchedMongoDBAtlasSemanticCache]
 )
+@pytest.mark.parametrize("remove_score", [True, False])
 def test_mongodb_cache(
+    remove_score: bool,
     cacher: Union[MongoDBCache, MongoDBAtlasSemanticCache],
     prompt: Union[str, List[BaseMessage]],
     llm: Union[str, FakeLLM, FakeChatModel],
     response: List[Generation],
 ) -> None:
     llm_cache(cacher)
+    if remove_score:
+        get_llm_cache().score_threshold = None  # type: ignore
     try:
         _execute_test(prompt, llm, response)
     finally:
```