langchain[minor]: Databricks vector search self query integration (#20627)

- Enable self querying feature for databricks vector search

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
This commit is contained in:
Sivaudha
2024-04-19 05:44:38 +02:00
committed by GitHub
parent 6d530481c1
commit baedc3ec0a
4 changed files with 785 additions and 1 deletions

View File

@@ -7,6 +7,7 @@ from langchain_community.vectorstores import (
AstraDB,
Chroma,
DashVector,
DatabricksVectorSearch,
DeepLake,
Dingo,
Milvus,
@@ -43,6 +44,9 @@ from langchain.chains.query_constructor.schema import AttributeInfo
from langchain.retrievers.self_query.astradb import AstraDBTranslator
from langchain.retrievers.self_query.chroma import ChromaTranslator
from langchain.retrievers.self_query.dashvector import DashvectorTranslator
from langchain.retrievers.self_query.databricks_vector_search import (
DatabricksVectorSearchTranslator,
)
from langchain.retrievers.self_query.deeplake import DeepLakeTranslator
from langchain.retrievers.self_query.dingo import DingoDBTranslator
from langchain.retrievers.self_query.elasticsearch import ElasticsearchTranslator
@@ -85,7 +89,8 @@ def _get_builtin_translator(vectorstore: VectorStore) -> Visitor:
OpenSearchVectorSearch: OpenSearchTranslator,
MongoDBAtlasVectorSearch: MongoDBAtlasTranslator,
}
if isinstance(vectorstore, DatabricksVectorSearch):
return DatabricksVectorSearchTranslator()
if isinstance(vectorstore, Qdrant):
return QdrantTranslator(metadata_key=vectorstore.metadata_payload_key)
elif isinstance(vectorstore, MyScale):

View File

@@ -0,0 +1,90 @@
from collections import ChainMap
from itertools import chain
from typing import Dict, Tuple
from langchain.chains.query_constructor.ir import (
Comparator,
Comparison,
Operation,
Operator,
StructuredQuery,
Visitor,
)
_COMPARATOR_TO_SYMBOL = {
Comparator.EQ: "",
Comparator.GT: " >",
Comparator.GTE: " >=",
Comparator.LT: " <",
Comparator.LTE: " <=",
Comparator.IN: "",
Comparator.LIKE: " LIKE",
}
class DatabricksVectorSearchTranslator(Visitor):
"""Translate `Databricks vector search` internal query language elements to
valid filters."""
"""Subset of allowed logical operators."""
allowed_operators = [Operator.AND, Operator.NOT, Operator.OR]
"""Subset of allowed logical comparators."""
allowed_comparators = [
Comparator.EQ,
Comparator.GT,
Comparator.GTE,
Comparator.LT,
Comparator.LTE,
Comparator.IN,
Comparator.LIKE,
]
def _visit_and_operation(self, operation: Operation) -> Dict:
return dict(ChainMap(*[arg.accept(self) for arg in operation.arguments]))
def _visit_or_operation(self, operation: Operation) -> Dict:
filter_args = [arg.accept(self) for arg in operation.arguments]
flattened_args = list(
chain.from_iterable(filter_arg.items() for filter_arg in filter_args)
)
return {
" OR ".join(key for key, _ in flattened_args): [
value for _, value in flattened_args
]
}
def _visit_not_operation(self, operation: Operation) -> Dict:
if len(operation.arguments) > 1:
raise ValueError(
f'"{operation.operator.value}" can have only one argument '
f"in Databricks vector search"
)
filter_arg = operation.arguments[0].accept(self)
return {
f"{colum_with_bool_expression} NOT": value
for colum_with_bool_expression, value in filter_arg.items()
}
def visit_operation(self, operation: Operation) -> Dict:
self._validate_func(operation.operator)
if operation.operator == Operator.AND:
return self._visit_and_operation(operation)
elif operation.operator == Operator.OR:
return self._visit_or_operation(operation)
elif operation.operator == Operator.NOT:
return self._visit_not_operation(operation)
def visit_comparison(self, comparison: Comparison) -> Dict:
self._validate_func(comparison.comparator)
comparator_symbol = _COMPARATOR_TO_SYMBOL[comparison.comparator]
return {f"{comparison.attribute}{comparator_symbol}": comparison.value}
def visit_structured_query(
self, structured_query: StructuredQuery
) -> Tuple[str, dict]:
if structured_query.filter is None:
kwargs = {}
else:
kwargs = {"filters": structured_query.filter.accept(self)}
return structured_query.query, kwargs

View File

@@ -0,0 +1,141 @@
from typing import Any, Dict, Tuple
import pytest
from langchain.chains.query_constructor.ir import (
Comparator,
Comparison,
Operation,
Operator,
StructuredQuery,
)
from langchain.retrievers.self_query.databricks_vector_search import (
DatabricksVectorSearchTranslator,
)
DEFAULT_TRANSLATOR = DatabricksVectorSearchTranslator()
@pytest.mark.parametrize(
"triplet",
[
(Comparator.EQ, 2, {"foo": 2}),
(Comparator.GT, 2, {"foo >": 2}),
(Comparator.GTE, 2, {"foo >=": 2}),
(Comparator.LT, 2, {"foo <": 2}),
(Comparator.LTE, 2, {"foo <=": 2}),
(Comparator.IN, ["bar", "abc"], {"foo": ["bar", "abc"]}),
(Comparator.LIKE, "bar", {"foo LIKE": "bar"}),
],
)
def test_visit_comparison(triplet: Tuple[Comparator, Any, str]) -> None:
comparator, value, expected = triplet
comp = Comparison(comparator=comparator, attribute="foo", value=value)
actual = DEFAULT_TRANSLATOR.visit_comparison(comp)
assert expected == actual
def test_visit_operation_and() -> None:
op = Operation(
operator=Operator.AND,
arguments=[
Comparison(comparator=Comparator.LT, attribute="foo", value=2),
Comparison(comparator=Comparator.EQ, attribute="bar", value="baz"),
],
)
expected = {"foo <": 2, "bar": "baz"}
actual = DEFAULT_TRANSLATOR.visit_operation(op)
assert expected == actual
def test_visit_operation_or() -> None:
op = Operation(
operator=Operator.OR,
arguments=[
Comparison(comparator=Comparator.EQ, attribute="foo", value=2),
Comparison(comparator=Comparator.EQ, attribute="bar", value="baz"),
],
)
expected = {"foo OR bar": [2, "baz"]}
actual = DEFAULT_TRANSLATOR.visit_operation(op)
assert expected == actual
def test_visit_operation_not() -> None:
op = Operation(
operator=Operator.NOT,
arguments=[
Comparison(comparator=Comparator.EQ, attribute="foo", value=2),
],
)
expected = {"foo NOT": 2}
actual = DEFAULT_TRANSLATOR.visit_operation(op)
assert expected == actual
def test_visit_operation_not_that_raises_for_more_than_one_filter_condition() -> None:
with pytest.raises(Exception) as exc_info:
op = Operation(
operator=Operator.NOT,
arguments=[
Comparison(comparator=Comparator.EQ, attribute="foo", value=2),
Comparison(comparator=Comparator.EQ, attribute="bar", value="baz"),
],
)
DEFAULT_TRANSLATOR.visit_operation(op)
assert (
str(exc_info.value) == '"not" can have only one argument in '
"Databricks vector search"
)
def test_visit_structured_query_with_no_filter() -> None:
query = "What is the capital of France?"
structured_query = StructuredQuery(
query=query,
filter=None,
)
expected: Tuple[str, Dict] = (query, {})
actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query)
assert expected == actual
def test_visit_structured_query_with_one_arg_filter() -> None:
query = "What is the capital of France?"
comp = Comparison(comparator=Comparator.EQ, attribute="country", value="France")
structured_query = StructuredQuery(
query=query,
filter=comp,
)
expected = (query, {"filters": {"country": "France"}})
actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query)
assert expected == actual
def test_visit_structured_query_with_multiple_arg_filter_and_operator() -> None:
query = "What is the capital of France in the years between 1888 and 1900?"
op = Operation(
operator=Operator.AND,
arguments=[
Comparison(comparator=Comparator.EQ, attribute="country", value="France"),
Comparison(comparator=Comparator.GTE, attribute="year", value=1888),
Comparison(comparator=Comparator.LTE, attribute="year", value=1900),
],
)
structured_query = StructuredQuery(
query=query,
filter=op,
)
expected = (
query,
{"filters": {"country": "France", "year >=": 1888, "year <=": 1900}},
)
actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query)
assert expected == actual