mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-31 16:39:20 +00:00
community[patch], langchain[minor]: Add retriever self_query and score_threshold in DingoDB (#18106)
This commit is contained in:
parent
d039dcb6ba
commit
6a08134661
File diff suppressed because one or more lines are too long
496
docs/docs/integrations/retrievers/self_query/dingo.ipynb
Normal file
496
docs/docs/integrations/retrievers/self_query/dingo.ipynb
Normal file
File diff suppressed because one or more lines are too long
@ -145,7 +145,7 @@ class Dingo(VectorStore):
|
||||
List of Documents most similar to the query and score for each
|
||||
"""
|
||||
docs_and_scores = self.similarity_search_with_score(
|
||||
query, k=k, search_params=search_params
|
||||
query, k=k, search_params=search_params, **kwargs
|
||||
)
|
||||
return [doc for doc, _ in docs_and_scores]
|
||||
|
||||
@ -177,9 +177,15 @@ class Dingo(VectorStore):
|
||||
return []
|
||||
|
||||
for res in results[0]["vectorWithDistances"]:
|
||||
score = res["distance"]
|
||||
if (
|
||||
"score_threshold" in kwargs
|
||||
and kwargs.get("score_threshold") is not None
|
||||
):
|
||||
if score > kwargs.get("score_threshold"):
|
||||
continue
|
||||
metadatas = res["scalarData"]
|
||||
id = res["id"]
|
||||
score = res["distance"]
|
||||
text = metadatas[self._text_key]["fields"][0]["data"]
|
||||
metadata = {"id": id, "text": text, "score": score}
|
||||
for meta_key in metadatas.keys():
|
||||
|
@ -7,6 +7,7 @@ from langchain_community.vectorstores import (
|
||||
Chroma,
|
||||
DashVector,
|
||||
DeepLake,
|
||||
Dingo,
|
||||
ElasticsearchStore,
|
||||
Milvus,
|
||||
MongoDBAtlasVectorSearch,
|
||||
@ -39,6 +40,7 @@ from langchain.retrievers.self_query.astradb import AstraDBTranslator
|
||||
from langchain.retrievers.self_query.chroma import ChromaTranslator
|
||||
from langchain.retrievers.self_query.dashvector import DashvectorTranslator
|
||||
from langchain.retrievers.self_query.deeplake import DeepLakeTranslator
|
||||
from langchain.retrievers.self_query.dingo import DingoDBTranslator
|
||||
from langchain.retrievers.self_query.elasticsearch import ElasticsearchTranslator
|
||||
from langchain.retrievers.self_query.milvus import MilvusTranslator
|
||||
from langchain.retrievers.self_query.mongodb_atlas import MongoDBAtlasTranslator
|
||||
@ -65,6 +67,7 @@ def _get_builtin_translator(vectorstore: VectorStore) -> Visitor:
|
||||
Pinecone: PineconeTranslator,
|
||||
Chroma: ChromaTranslator,
|
||||
DashVector: DashvectorTranslator,
|
||||
Dingo: DingoDBTranslator,
|
||||
Weaviate: WeaviateTranslator,
|
||||
Vectara: VectaraTranslator,
|
||||
Qdrant: QdrantTranslator,
|
||||
|
49
libs/langchain/langchain/retrievers/self_query/dingo.py
Normal file
49
libs/langchain/langchain/retrievers/self_query/dingo.py
Normal file
@ -0,0 +1,49 @@
|
||||
from typing import Tuple, Union
|
||||
|
||||
from langchain.chains.query_constructor.ir import (
|
||||
Comparator,
|
||||
Comparison,
|
||||
Operation,
|
||||
Operator,
|
||||
StructuredQuery,
|
||||
Visitor,
|
||||
)
|
||||
|
||||
|
||||
class DingoDBTranslator(Visitor):
|
||||
"""Translate `DingoDB` internal query language elements to valid filters."""
|
||||
|
||||
allowed_comparators = (
|
||||
Comparator.EQ,
|
||||
Comparator.NE,
|
||||
Comparator.LT,
|
||||
Comparator.LTE,
|
||||
Comparator.GT,
|
||||
Comparator.GTE,
|
||||
)
|
||||
"""Subset of allowed logical comparators."""
|
||||
allowed_operators = (Operator.AND, Operator.OR)
|
||||
"""Subset of allowed logical operators."""
|
||||
|
||||
def _format_func(self, func: Union[Operator, Comparator]) -> str:
|
||||
self._validate_func(func)
|
||||
return f"${func.value}"
|
||||
|
||||
def visit_operation(self, operation: Operation) -> Operation:
|
||||
return operation
|
||||
|
||||
def visit_comparison(self, comparison: Comparison) -> Comparison:
|
||||
return comparison
|
||||
|
||||
def visit_structured_query(
|
||||
self, structured_query: StructuredQuery
|
||||
) -> Tuple[str, dict]:
|
||||
if structured_query.filter is None:
|
||||
kwargs = {}
|
||||
else:
|
||||
kwargs = {
|
||||
"search_params": {
|
||||
"langchain_expr": structured_query.filter.accept(self)
|
||||
}
|
||||
}
|
||||
return structured_query.query, kwargs
|
@ -0,0 +1,99 @@
|
||||
from typing import Dict, Tuple
|
||||
|
||||
from langchain.chains.query_constructor.ir import (
|
||||
Comparator,
|
||||
Comparison,
|
||||
Operation,
|
||||
Operator,
|
||||
StructuredQuery,
|
||||
)
|
||||
from langchain.retrievers.self_query.dingo import DingoDBTranslator
|
||||
|
||||
DEFAULT_TRANSLATOR = DingoDBTranslator()
|
||||
|
||||
|
||||
def test_visit_comparison() -> None:
|
||||
comp = Comparison(comparator=Comparator.LT, attribute="foo", value=["1", "2"])
|
||||
expected = Comparison(comparator=Comparator.LT, attribute="foo", value=["1", "2"])
|
||||
actual = DEFAULT_TRANSLATOR.visit_comparison(comp)
|
||||
assert expected == actual
|
||||
|
||||
|
||||
def test_visit_operation() -> None:
|
||||
op = Operation(
|
||||
operator=Operator.AND,
|
||||
arguments=[
|
||||
Comparison(comparator=Comparator.LT, attribute="foo", value=2),
|
||||
Comparison(comparator=Comparator.EQ, attribute="bar", value="baz"),
|
||||
],
|
||||
)
|
||||
expected = Operation(
|
||||
operator=Operator.AND,
|
||||
arguments=[
|
||||
Comparison(comparator=Comparator.LT, attribute="foo", value=2),
|
||||
Comparison(comparator=Comparator.EQ, attribute="bar", value="baz"),
|
||||
],
|
||||
)
|
||||
actual = DEFAULT_TRANSLATOR.visit_operation(op)
|
||||
assert expected == actual
|
||||
|
||||
|
||||
def test_visit_structured_query() -> None:
|
||||
query = "What is the capital of France?"
|
||||
|
||||
structured_query = StructuredQuery(
|
||||
query=query,
|
||||
filter=None,
|
||||
)
|
||||
expected: Tuple[str, Dict] = (query, {})
|
||||
actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query)
|
||||
assert expected == actual
|
||||
|
||||
comp = Comparison(comparator=Comparator.LT, attribute="foo", value=["1", "2"])
|
||||
structured_query = StructuredQuery(
|
||||
query=query,
|
||||
filter=comp,
|
||||
)
|
||||
expected = (
|
||||
query,
|
||||
{
|
||||
"search_params": {
|
||||
"langchain_expr": Comparison(
|
||||
comparator=Comparator.LT, attribute="foo", value=["1", "2"]
|
||||
)
|
||||
}
|
||||
},
|
||||
)
|
||||
actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query)
|
||||
assert expected == actual
|
||||
|
||||
op = Operation(
|
||||
operator=Operator.AND,
|
||||
arguments=[
|
||||
Comparison(comparator=Comparator.LT, attribute="foo", value=2),
|
||||
Comparison(comparator=Comparator.EQ, attribute="bar", value="baz"),
|
||||
],
|
||||
)
|
||||
structured_query = StructuredQuery(
|
||||
query=query,
|
||||
filter=op,
|
||||
)
|
||||
expected = (
|
||||
query,
|
||||
{
|
||||
"search_params": {
|
||||
"langchain_expr": Operation(
|
||||
operator=Operator.AND,
|
||||
arguments=[
|
||||
Comparison(comparator=Comparator.LT, attribute="foo", value=2),
|
||||
Comparison(
|
||||
comparator=Comparator.EQ, attribute="bar", value="baz"
|
||||
),
|
||||
],
|
||||
)
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query)
|
||||
assert expected == actual
|
Loading…
Reference in New Issue
Block a user