community[patch], langchain[minor]: Add retriever self_query and score_threshold in DingoDB (#18106)

This commit is contained in:
Hech 2024-03-06 07:47:29 +08:00 committed by GitHub
parent d039dcb6ba
commit 6a08134661
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 656 additions and 3 deletions

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -145,7 +145,7 @@ class Dingo(VectorStore):
List of Documents most similar to the query and score for each
"""
docs_and_scores = self.similarity_search_with_score(
query, k=k, search_params=search_params
query, k=k, search_params=search_params, **kwargs
)
return [doc for doc, _ in docs_and_scores]
@ -177,9 +177,15 @@ class Dingo(VectorStore):
return []
for res in results[0]["vectorWithDistances"]:
score = res["distance"]
if (
"score_threshold" in kwargs
and kwargs.get("score_threshold") is not None
):
if score > kwargs.get("score_threshold"):
continue
metadatas = res["scalarData"]
id = res["id"]
score = res["distance"]
text = metadatas[self._text_key]["fields"][0]["data"]
metadata = {"id": id, "text": text, "score": score}
for meta_key in metadatas.keys():

View File

@ -7,6 +7,7 @@ from langchain_community.vectorstores import (
Chroma,
DashVector,
DeepLake,
Dingo,
ElasticsearchStore,
Milvus,
MongoDBAtlasVectorSearch,
@ -39,6 +40,7 @@ from langchain.retrievers.self_query.astradb import AstraDBTranslator
from langchain.retrievers.self_query.chroma import ChromaTranslator
from langchain.retrievers.self_query.dashvector import DashvectorTranslator
from langchain.retrievers.self_query.deeplake import DeepLakeTranslator
from langchain.retrievers.self_query.dingo import DingoDBTranslator
from langchain.retrievers.self_query.elasticsearch import ElasticsearchTranslator
from langchain.retrievers.self_query.milvus import MilvusTranslator
from langchain.retrievers.self_query.mongodb_atlas import MongoDBAtlasTranslator
@ -65,6 +67,7 @@ def _get_builtin_translator(vectorstore: VectorStore) -> Visitor:
Pinecone: PineconeTranslator,
Chroma: ChromaTranslator,
DashVector: DashvectorTranslator,
Dingo: DingoDBTranslator,
Weaviate: WeaviateTranslator,
Vectara: VectaraTranslator,
Qdrant: QdrantTranslator,

View File

@ -0,0 +1,49 @@
from typing import Tuple, Union
from langchain.chains.query_constructor.ir import (
Comparator,
Comparison,
Operation,
Operator,
StructuredQuery,
Visitor,
)
class DingoDBTranslator(Visitor):
    """Visitor that packages structured-query filters for `DingoDB`.

    Filter nodes (operations and comparisons) are handed back unchanged; a
    structured query is turned into its query string plus the keyword
    arguments that carry the filter.
    """

    allowed_comparators = (
        Comparator.EQ,
        Comparator.NE,
        Comparator.LT,
        Comparator.LTE,
        Comparator.GT,
        Comparator.GTE,
    )
    """Comparators accepted by this translator."""

    allowed_operators = (Operator.AND, Operator.OR)
    """Logical operators accepted by this translator."""

    def _format_func(self, func: Union[Operator, Comparator]) -> str:
        """Validate ``func`` and return its value prefixed with ``$``."""
        self._validate_func(func)
        name = func.value
        return f"${name}"

    def visit_operation(self, operation: Operation) -> Operation:
        """Return ``operation`` as-is, without translating it."""
        return operation

    def visit_comparison(self, comparison: Comparison) -> Comparison:
        """Return ``comparison`` as-is, without translating it."""
        return comparison

    def visit_structured_query(
        self, structured_query: StructuredQuery
    ) -> Tuple[str, dict]:
        """Split a structured query into its query string and search kwargs.

        Args:
            structured_query: The query whose ``query`` and ``filter`` are read.

        Returns:
            A ``(query, kwargs)`` tuple. ``kwargs`` is empty when the query has
            no filter; otherwise it holds the result of the filter's
            ``accept(self)`` under ``search_params["langchain_expr"]``.
        """
        query_filter = structured_query.filter
        if query_filter is None:
            return structured_query.query, {}

        search_params = {"langchain_expr": query_filter.accept(self)}
        return structured_query.query, {"search_params": search_params}

View File

@ -0,0 +1,99 @@
from typing import Dict, Tuple
from langchain.chains.query_constructor.ir import (
Comparator,
Comparison,
Operation,
Operator,
StructuredQuery,
)
from langchain.retrievers.self_query.dingo import DingoDBTranslator
# Translator instance shared by the tests in this module.
DEFAULT_TRANSLATOR = DingoDBTranslator()
def test_visit_comparison() -> None:
    """``visit_comparison`` hands back a comparison equal to its input."""

    def build() -> Comparison:
        return Comparison(
            comparator=Comparator.LT, attribute="foo", value=["1", "2"]
        )

    result = DEFAULT_TRANSLATOR.visit_comparison(build())
    assert build() == result
def test_visit_operation() -> None:
    """``visit_operation`` hands back an operation equal to its input."""

    def build() -> Operation:
        # A fresh AND operation per call so input and expectation are distinct.
        lt_foo = Comparison(comparator=Comparator.LT, attribute="foo", value=2)
        eq_bar = Comparison(comparator=Comparator.EQ, attribute="bar", value="baz")
        return Operation(operator=Operator.AND, arguments=[lt_foo, eq_bar])

    result = DEFAULT_TRANSLATOR.visit_operation(build())
    assert build() == result
def test_visit_structured_query() -> None:
    """Check the ``(query, kwargs)`` result for each kind of filter.

    Covers a query with no filter, one filtered by a single comparison, and
    one filtered by an AND operation.
    """
    query = "What is the capital of France?"

    def lt_comparison() -> Comparison:
        return Comparison(
            comparator=Comparator.LT, attribute="foo", value=["1", "2"]
        )

    def and_operation() -> Operation:
        lt_foo = Comparison(comparator=Comparator.LT, attribute="foo", value=2)
        eq_bar = Comparison(comparator=Comparator.EQ, attribute="bar", value="baz")
        return Operation(operator=Operator.AND, arguments=[lt_foo, eq_bar])

    # No filter: the kwargs dict is empty.
    expected: Tuple[str, Dict] = (query, {})
    unfiltered = StructuredQuery(query=query, filter=None)
    assert expected == DEFAULT_TRANSLATOR.visit_structured_query(unfiltered)

    # Comparison filter: passed through under search_params["langchain_expr"].
    by_comparison = StructuredQuery(query=query, filter=lt_comparison())
    expected = (query, {"search_params": {"langchain_expr": lt_comparison()}})
    assert expected == DEFAULT_TRANSLATOR.visit_structured_query(by_comparison)

    # Operation filter: likewise passed through unchanged.
    by_operation = StructuredQuery(query=query, filter=and_operation())
    expected = (query, {"search_params": {"langchain_expr": and_operation()}})
    assert expected == DEFAULT_TRANSLATOR.visit_structured_query(by_operation)