From 859748419577b9a3a5935524732fb8c254d08ad8 Mon Sep 17 00:00:00 2001 From: ChengZi Date: Thu, 18 Jan 2024 01:41:23 +0800 Subject: [PATCH] langchain[patch]: support more comparators in Milvus self-querying retriever (#16076) - **Description:** Support IN and LIKE comparators in Milvus self-querying retriever, based on [Boolean Expression Rules](https://milvus.io/docs/boolean.md) - **Issue:** No - **Dependencies:** No - **Twitter handle:** No Signed-off-by: ChengZi --- .../langchain/retrievers/self_query/milvus.py | 20 ++++++++++++---- .../retrievers/self_query/test_milvus.py | 23 +++++++++++++++---- 2 files changed, 33 insertions(+), 10 deletions(-) diff --git a/libs/langchain/langchain/retrievers/self_query/milvus.py b/libs/langchain/langchain/retrievers/self_query/milvus.py index f855bf99093..dbc61f6f712 100644 --- a/libs/langchain/langchain/retrievers/self_query/milvus.py +++ b/libs/langchain/langchain/retrievers/self_query/milvus.py @@ -16,28 +16,36 @@ COMPARATOR_TO_BER = { Comparator.GTE: ">=", Comparator.LT: "<", Comparator.LTE: "<=", + Comparator.IN: "in", + Comparator.LIKE: "like", } UNARY_OPERATORS = [Operator.NOT] -def process_value(value: Union[int, float, str]) -> str: +def process_value(value: Union[int, float, str], comparator: Comparator) -> str: """Convert a value to a string and add double quotes if it is a string. It required for comparators involving strings. Args: value: The value to convert. + comparator: The comparator. Returns: The converted value as a string. """ # if isinstance(value, str): - # If the value is already a string, add double quotes - return f'"{value}"' + if comparator is Comparator.LIKE: + # If the comparator is LIKE, add a percent sign after it for prefix matching + # and add double quotes + return f'"{value}%"' + else: + # If the value is already a string, add double quotes + return f'"{value}"' else: - # If the valueis not a string, convert it to a string without double quotes + # If the value is not a string, convert it to a string without double quotes return str(value) @@ -54,6 +62,8 @@ class MilvusTranslator(Visitor): Comparator.GTE, Comparator.LT, Comparator.LTE, + Comparator.IN, + Comparator.LIKE, ] def _format_func(self, func: Union[Operator, Comparator]) -> str: @@ -78,7 +88,7 @@ class MilvusTranslator(Visitor): def visit_comparison(self, comparison: Comparison) -> str: comparator = self._format_func(comparison.comparator) - processed_value = process_value(comparison.value) + processed_value = process_value(comparison.value, comparison.comparator) attribute = comparison.attribute return "( " + attribute + " " + comparator + " " + processed_value + " )" diff --git a/libs/langchain/tests/unit_tests/retrievers/self_query/test_milvus.py b/libs/langchain/tests/unit_tests/retrievers/self_query/test_milvus.py index d4449744044..4a96c18e274 100644 --- a/libs/langchain/tests/unit_tests/retrievers/self_query/test_milvus.py +++ b/libs/langchain/tests/unit_tests/retrievers/self_query/test_milvus.py @@ -1,4 +1,6 @@ -from typing import Dict, Tuple +from typing import Any, Dict, Tuple + +import pytest from langchain.chains.query_constructor.ir import ( Comparator, @@ -12,11 +14,22 @@ from langchain.retrievers.self_query.milvus import MilvusTranslator DEFAULT_TRANSLATOR = MilvusTranslator() -def test_visit_comparison() -> None: - comp = Comparison(comparator=Comparator.LT, attribute="foo", value=4) - expected = "( foo < 4 )" +@pytest.mark.parametrize( + "triplet", + [ + (Comparator.EQ, 2, "( foo == 2 )"), + (Comparator.GT, 2, "( foo > 2 )"), + (Comparator.GTE, 2, "( foo >= 2 )"), + (Comparator.LT, 2, "( foo < 2 )"), + (Comparator.LTE, 2, "( foo <= 2 )"), + (Comparator.IN, ["bar", "abc"], "( foo in ['bar', 'abc'] )"), + (Comparator.LIKE, "bar", '( foo like "bar%" )'), + ], +) +def test_visit_comparison(triplet: Tuple[Comparator, Any, str]) -> None: + comparator, value, expected = triplet + comp = Comparison(comparator=comparator, attribute="foo", value=value) actual = DEFAULT_TRANSLATOR.visit_comparison(comp) - assert expected == actual