mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-02 01:23:07 +00:00
mongodb[patch]: Added scoring threshold to caching (#19286)
## Description Semantic Cache can retrieve noisy information if the score threshold for the value is too low. Adding the ability to set a `score_threshold` on cache construction can allow for less noisy scores to appear. - [x] **Add tests and docs** 1. Added tests that confirm the `score_threshold` query is valid. - [x] **Lint and test**: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified. See contribution guidelines for more: https://python.langchain.com/docs/contributing/ --------- Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
parent
30e4a35d7a
commit
f8078e41e5
@ -217,6 +217,7 @@ class MongoDBAtlasSemanticCache(BaseCache, MongoDBAtlasVectorSearch):
|
||||
database_name: str = "default",
|
||||
index_name: str = "default",
|
||||
wait_until_ready: bool = False,
|
||||
score_threshold: Optional[float] = None,
|
||||
**kwargs: Dict[str, Any],
|
||||
):
|
||||
"""
|
||||
@ -237,6 +238,7 @@ class MongoDBAtlasSemanticCache(BaseCache, MongoDBAtlasVectorSearch):
|
||||
"""
|
||||
client = _generate_mongo_client(connection_string)
|
||||
self.collection = client[database_name][collection_name]
|
||||
self.score_threshold = score_threshold
|
||||
self._wait_until_ready = wait_until_ready
|
||||
super().__init__(
|
||||
collection=self.collection,
|
||||
@ -247,8 +249,17 @@ class MongoDBAtlasSemanticCache(BaseCache, MongoDBAtlasVectorSearch):
|
||||
|
||||
def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
|
||||
"""Look up based on prompt and llm_string."""
|
||||
post_filter_pipeline = (
|
||||
[{"$match": {"score": {"$gte": self.score_threshold}}}]
|
||||
if self.score_threshold
|
||||
else None
|
||||
)
|
||||
|
||||
search_response = self.similarity_search_with_score(
|
||||
prompt, 1, pre_filter={self.LLM: {"$eq": llm_string}}
|
||||
prompt,
|
||||
1,
|
||||
pre_filter={self.LLM: {"$eq": llm_string}},
|
||||
post_filter_pipeline=post_filter_pipeline,
|
||||
)
|
||||
if search_response:
|
||||
return_val = search_response[0][0].metadata.get(self.RETURN_VAL)
|
||||
|
@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "langchain-mongodb"
|
||||
version = "0.1.2"
|
||||
version = "0.1.3"
|
||||
description = "An integration package connecting MongoDB and LangChain"
|
||||
authors = []
|
||||
readme = "README.md"
|
||||
|
@ -30,6 +30,7 @@ def llm_cache(cls: Any) -> BaseCache:
|
||||
collection_name=COLLECTION,
|
||||
database_name=DATABASE,
|
||||
index_name=INDEX_NAME,
|
||||
score_threshold=0.5,
|
||||
wait_until_ready=True,
|
||||
)
|
||||
)
|
||||
@ -92,13 +93,17 @@ def _execute_test(
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("cacher", [MongoDBCache, MongoDBAtlasSemanticCache])
|
||||
@pytest.mark.parametrize("remove_score", [True, False])
|
||||
def test_mongodb_cache(
|
||||
remove_score: bool,
|
||||
cacher: Union[MongoDBCache, MongoDBAtlasSemanticCache],
|
||||
prompt: Union[str, List[BaseMessage]],
|
||||
llm: Union[str, FakeLLM, FakeChatModel],
|
||||
response: List[Generation],
|
||||
) -> None:
|
||||
llm_cache(cacher)
|
||||
if remove_score:
|
||||
get_llm_cache().score_threshold = None # type: ignore
|
||||
try:
|
||||
_execute_test(prompt, llm, response)
|
||||
finally:
|
||||
|
@ -54,6 +54,7 @@ class PatchedMongoDBAtlasSemanticCache(MongoDBAtlasSemanticCache):
|
||||
):
|
||||
self.collection = MockCollection()
|
||||
self._wait_until_ready = False
|
||||
self.score_threshold = None
|
||||
MongoDBAtlasVectorSearch.__init__(
|
||||
self,
|
||||
self.collection,
|
||||
@ -144,13 +145,17 @@ def _execute_test(
|
||||
@pytest.mark.parametrize(
|
||||
"cacher", [PatchedMongoDBCache, PatchedMongoDBAtlasSemanticCache]
|
||||
)
|
||||
@pytest.mark.parametrize("remove_score", [True, False])
|
||||
def test_mongodb_cache(
|
||||
remove_score: bool,
|
||||
cacher: Union[MongoDBCache, MongoDBAtlasSemanticCache],
|
||||
prompt: Union[str, List[BaseMessage]],
|
||||
llm: Union[str, FakeLLM, FakeChatModel],
|
||||
response: List[Generation],
|
||||
) -> None:
|
||||
llm_cache(cacher)
|
||||
if remove_score:
|
||||
get_llm_cache().score_threshold = None # type: ignore
|
||||
try:
|
||||
_execute_test(prompt, llm, response)
|
||||
finally:
|
||||
|
Loading…
Reference in New Issue
Block a user