mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-15 22:44:36 +00:00
langchain[minor]: Databricks vector search self query integration (#20627)
- Enable self querying feature for databricks vector search --------- Co-authored-by: Bagatur <baskaryan@gmail.com> Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
This commit is contained in:
@@ -7,6 +7,7 @@ from langchain_community.vectorstores import (
|
||||
AstraDB,
|
||||
Chroma,
|
||||
DashVector,
|
||||
DatabricksVectorSearch,
|
||||
DeepLake,
|
||||
Dingo,
|
||||
Milvus,
|
||||
@@ -43,6 +44,9 @@ from langchain.chains.query_constructor.schema import AttributeInfo
|
||||
from langchain.retrievers.self_query.astradb import AstraDBTranslator
|
||||
from langchain.retrievers.self_query.chroma import ChromaTranslator
|
||||
from langchain.retrievers.self_query.dashvector import DashvectorTranslator
|
||||
from langchain.retrievers.self_query.databricks_vector_search import (
|
||||
DatabricksVectorSearchTranslator,
|
||||
)
|
||||
from langchain.retrievers.self_query.deeplake import DeepLakeTranslator
|
||||
from langchain.retrievers.self_query.dingo import DingoDBTranslator
|
||||
from langchain.retrievers.self_query.elasticsearch import ElasticsearchTranslator
|
||||
@@ -85,7 +89,8 @@ def _get_builtin_translator(vectorstore: VectorStore) -> Visitor:
|
||||
OpenSearchVectorSearch: OpenSearchTranslator,
|
||||
MongoDBAtlasVectorSearch: MongoDBAtlasTranslator,
|
||||
}
|
||||
|
||||
if isinstance(vectorstore, DatabricksVectorSearch):
|
||||
return DatabricksVectorSearchTranslator()
|
||||
if isinstance(vectorstore, Qdrant):
|
||||
return QdrantTranslator(metadata_key=vectorstore.metadata_payload_key)
|
||||
elif isinstance(vectorstore, MyScale):
|
||||
|
@@ -0,0 +1,90 @@
|
||||
from collections import ChainMap
|
||||
from itertools import chain
|
||||
from typing import Dict, Tuple
|
||||
|
||||
from langchain.chains.query_constructor.ir import (
|
||||
Comparator,
|
||||
Comparison,
|
||||
Operation,
|
||||
Operator,
|
||||
StructuredQuery,
|
||||
Visitor,
|
||||
)
|
||||
|
||||
_COMPARATOR_TO_SYMBOL = {
|
||||
Comparator.EQ: "",
|
||||
Comparator.GT: " >",
|
||||
Comparator.GTE: " >=",
|
||||
Comparator.LT: " <",
|
||||
Comparator.LTE: " <=",
|
||||
Comparator.IN: "",
|
||||
Comparator.LIKE: " LIKE",
|
||||
}
|
||||
|
||||
|
||||
class DatabricksVectorSearchTranslator(Visitor):
|
||||
"""Translate `Databricks vector search` internal query language elements to
|
||||
valid filters."""
|
||||
|
||||
"""Subset of allowed logical operators."""
|
||||
allowed_operators = [Operator.AND, Operator.NOT, Operator.OR]
|
||||
|
||||
"""Subset of allowed logical comparators."""
|
||||
allowed_comparators = [
|
||||
Comparator.EQ,
|
||||
Comparator.GT,
|
||||
Comparator.GTE,
|
||||
Comparator.LT,
|
||||
Comparator.LTE,
|
||||
Comparator.IN,
|
||||
Comparator.LIKE,
|
||||
]
|
||||
|
||||
def _visit_and_operation(self, operation: Operation) -> Dict:
|
||||
return dict(ChainMap(*[arg.accept(self) for arg in operation.arguments]))
|
||||
|
||||
def _visit_or_operation(self, operation: Operation) -> Dict:
|
||||
filter_args = [arg.accept(self) for arg in operation.arguments]
|
||||
flattened_args = list(
|
||||
chain.from_iterable(filter_arg.items() for filter_arg in filter_args)
|
||||
)
|
||||
return {
|
||||
" OR ".join(key for key, _ in flattened_args): [
|
||||
value for _, value in flattened_args
|
||||
]
|
||||
}
|
||||
|
||||
def _visit_not_operation(self, operation: Operation) -> Dict:
|
||||
if len(operation.arguments) > 1:
|
||||
raise ValueError(
|
||||
f'"{operation.operator.value}" can have only one argument '
|
||||
f"in Databricks vector search"
|
||||
)
|
||||
filter_arg = operation.arguments[0].accept(self)
|
||||
return {
|
||||
f"{colum_with_bool_expression} NOT": value
|
||||
for colum_with_bool_expression, value in filter_arg.items()
|
||||
}
|
||||
|
||||
def visit_operation(self, operation: Operation) -> Dict:
|
||||
self._validate_func(operation.operator)
|
||||
if operation.operator == Operator.AND:
|
||||
return self._visit_and_operation(operation)
|
||||
elif operation.operator == Operator.OR:
|
||||
return self._visit_or_operation(operation)
|
||||
elif operation.operator == Operator.NOT:
|
||||
return self._visit_not_operation(operation)
|
||||
|
||||
def visit_comparison(self, comparison: Comparison) -> Dict:
|
||||
self._validate_func(comparison.comparator)
|
||||
comparator_symbol = _COMPARATOR_TO_SYMBOL[comparison.comparator]
|
||||
return {f"{comparison.attribute}{comparator_symbol}": comparison.value}
|
||||
|
||||
def visit_structured_query(
|
||||
self, structured_query: StructuredQuery
|
||||
) -> Tuple[str, dict]:
|
||||
if structured_query.filter is None:
|
||||
kwargs = {}
|
||||
else:
|
||||
kwargs = {"filters": structured_query.filter.accept(self)}
|
||||
return structured_query.query, kwargs
|
@@ -0,0 +1,141 @@
|
||||
from typing import Any, Dict, Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.chains.query_constructor.ir import (
|
||||
Comparator,
|
||||
Comparison,
|
||||
Operation,
|
||||
Operator,
|
||||
StructuredQuery,
|
||||
)
|
||||
from langchain.retrievers.self_query.databricks_vector_search import (
|
||||
DatabricksVectorSearchTranslator,
|
||||
)
|
||||
|
||||
DEFAULT_TRANSLATOR = DatabricksVectorSearchTranslator()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"triplet",
|
||||
[
|
||||
(Comparator.EQ, 2, {"foo": 2}),
|
||||
(Comparator.GT, 2, {"foo >": 2}),
|
||||
(Comparator.GTE, 2, {"foo >=": 2}),
|
||||
(Comparator.LT, 2, {"foo <": 2}),
|
||||
(Comparator.LTE, 2, {"foo <=": 2}),
|
||||
(Comparator.IN, ["bar", "abc"], {"foo": ["bar", "abc"]}),
|
||||
(Comparator.LIKE, "bar", {"foo LIKE": "bar"}),
|
||||
],
|
||||
)
|
||||
def test_visit_comparison(triplet: Tuple[Comparator, Any, str]) -> None:
|
||||
comparator, value, expected = triplet
|
||||
comp = Comparison(comparator=comparator, attribute="foo", value=value)
|
||||
actual = DEFAULT_TRANSLATOR.visit_comparison(comp)
|
||||
assert expected == actual
|
||||
|
||||
|
||||
def test_visit_operation_and() -> None:
|
||||
op = Operation(
|
||||
operator=Operator.AND,
|
||||
arguments=[
|
||||
Comparison(comparator=Comparator.LT, attribute="foo", value=2),
|
||||
Comparison(comparator=Comparator.EQ, attribute="bar", value="baz"),
|
||||
],
|
||||
)
|
||||
expected = {"foo <": 2, "bar": "baz"}
|
||||
actual = DEFAULT_TRANSLATOR.visit_operation(op)
|
||||
assert expected == actual
|
||||
|
||||
|
||||
def test_visit_operation_or() -> None:
|
||||
op = Operation(
|
||||
operator=Operator.OR,
|
||||
arguments=[
|
||||
Comparison(comparator=Comparator.EQ, attribute="foo", value=2),
|
||||
Comparison(comparator=Comparator.EQ, attribute="bar", value="baz"),
|
||||
],
|
||||
)
|
||||
expected = {"foo OR bar": [2, "baz"]}
|
||||
actual = DEFAULT_TRANSLATOR.visit_operation(op)
|
||||
assert expected == actual
|
||||
|
||||
|
||||
def test_visit_operation_not() -> None:
|
||||
op = Operation(
|
||||
operator=Operator.NOT,
|
||||
arguments=[
|
||||
Comparison(comparator=Comparator.EQ, attribute="foo", value=2),
|
||||
],
|
||||
)
|
||||
expected = {"foo NOT": 2}
|
||||
actual = DEFAULT_TRANSLATOR.visit_operation(op)
|
||||
assert expected == actual
|
||||
|
||||
|
||||
def test_visit_operation_not_that_raises_for_more_than_one_filter_condition() -> None:
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
op = Operation(
|
||||
operator=Operator.NOT,
|
||||
arguments=[
|
||||
Comparison(comparator=Comparator.EQ, attribute="foo", value=2),
|
||||
Comparison(comparator=Comparator.EQ, attribute="bar", value="baz"),
|
||||
],
|
||||
)
|
||||
DEFAULT_TRANSLATOR.visit_operation(op)
|
||||
assert (
|
||||
str(exc_info.value) == '"not" can have only one argument in '
|
||||
"Databricks vector search"
|
||||
)
|
||||
|
||||
|
||||
def test_visit_structured_query_with_no_filter() -> None:
|
||||
query = "What is the capital of France?"
|
||||
structured_query = StructuredQuery(
|
||||
query=query,
|
||||
filter=None,
|
||||
)
|
||||
expected: Tuple[str, Dict] = (query, {})
|
||||
|
||||
actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query)
|
||||
assert expected == actual
|
||||
|
||||
|
||||
def test_visit_structured_query_with_one_arg_filter() -> None:
|
||||
query = "What is the capital of France?"
|
||||
comp = Comparison(comparator=Comparator.EQ, attribute="country", value="France")
|
||||
structured_query = StructuredQuery(
|
||||
query=query,
|
||||
filter=comp,
|
||||
)
|
||||
|
||||
expected = (query, {"filters": {"country": "France"}})
|
||||
|
||||
actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query)
|
||||
assert expected == actual
|
||||
|
||||
|
||||
def test_visit_structured_query_with_multiple_arg_filter_and_operator() -> None:
|
||||
query = "What is the capital of France in the years between 1888 and 1900?"
|
||||
|
||||
op = Operation(
|
||||
operator=Operator.AND,
|
||||
arguments=[
|
||||
Comparison(comparator=Comparator.EQ, attribute="country", value="France"),
|
||||
Comparison(comparator=Comparator.GTE, attribute="year", value=1888),
|
||||
Comparison(comparator=Comparator.LTE, attribute="year", value=1900),
|
||||
],
|
||||
)
|
||||
|
||||
structured_query = StructuredQuery(
|
||||
query=query,
|
||||
filter=op,
|
||||
)
|
||||
|
||||
expected = (
|
||||
query,
|
||||
{"filters": {"country": "France", "year >=": 1888, "year <=": 1900}},
|
||||
)
|
||||
|
||||
actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query)
|
||||
assert expected == actual
|
Reference in New Issue
Block a user