community[minor]: Self query retriever for HANA Cloud Vector Engine (#24494)

Description:

- This PR adds a self query retriever implementation for SAP HANA Cloud
Vector Engine. The retriever supports all operators except for contains.
- Issue: N/A
- Dependencies: no new dependencies added

**Add tests and docs:**
Added integration tests to:
libs/community/tests/unit_tests/query_constructors/test_hanavector.py

**Documentation for self query retriever:**
/docs/integrations/retrievers/self_query/hanavector_self_query.ipynb

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
This commit is contained in:
yonarw 2024-07-26 08:56:51 +02:00 committed by GitHub
parent 4f3b4fc7fe
commit b65ac8d39c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 408 additions and 2 deletions

View File

@ -0,0 +1,246 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# SAP HANA Cloud Vector Engine\n",
"\n",
"For more information on how to setup the SAP HANA vetor store, take a look at the [documentation](/docs/integrations/vectorstores/sap_hanavector.ipynb).\n",
"\n",
"We use the same setup here:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"# Use OPENAI_API_KEY env variable\n",
"# os.environ[\"OPENAI_API_KEY\"] = \"Your OpenAI API key\"\n",
"from hdbcli import dbapi\n",
"\n",
"# Use connection settings from the environment\n",
"connection = dbapi.connect(\n",
" address=os.environ.get(\"HANA_DB_ADDRESS\"),\n",
" port=os.environ.get(\"HANA_DB_PORT\"),\n",
" user=os.environ.get(\"HANA_DB_USER\"),\n",
" password=os.environ.get(\"HANA_DB_PASSWORD\"),\n",
" autocommit=True,\n",
" sslValidateCertificate=False,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To be able to self query with good performance we create additional metadata fields\n",
"for our vectorstore table in HANA:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Create custom table with attribute\n",
"cur = connection.cursor()\n",
"cur.execute(\"DROP TABLE LANGCHAIN_DEMO_SELF_QUERY\", ignoreErrors=True)\n",
"cur.execute(\n",
" (\n",
" \"\"\"CREATE TABLE \"LANGCHAIN_DEMO_SELF_QUERY\" (\n",
" \"name\" NVARCHAR(100), \"is_active\" BOOLEAN, \"id\" INTEGER, \"height\" DOUBLE,\n",
" \"VEC_TEXT\" NCLOB, \n",
" \"VEC_META\" NCLOB, \n",
" \"VEC_VECTOR\" REAL_VECTOR\n",
" )\"\"\"\n",
" )\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's add some documents."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.vectorstores.hanavector import HanaDB\n",
"from langchain_core.documents import Document\n",
"from langchain_openai import OpenAIEmbeddings\n",
"\n",
"embeddings = OpenAIEmbeddings()\n",
"\n",
"# Prepare some test documents\n",
"docs = [\n",
" Document(\n",
" page_content=\"First\",\n",
" metadata={\"name\": \"adam\", \"is_active\": True, \"id\": 1, \"height\": 10.0},\n",
" ),\n",
" Document(\n",
" page_content=\"Second\",\n",
" metadata={\"name\": \"bob\", \"is_active\": False, \"id\": 2, \"height\": 5.7},\n",
" ),\n",
" Document(\n",
" page_content=\"Third\",\n",
" metadata={\"name\": \"jane\", \"is_active\": True, \"id\": 3, \"height\": 2.4},\n",
" ),\n",
"]\n",
"\n",
"db = HanaDB(\n",
" connection=connection,\n",
" embedding=embeddings,\n",
" table_name=\"LANGCHAIN_DEMO_SELF_QUERY\",\n",
" specific_metadata_columns=[\"name\", \"is_active\", \"id\", \"height\"],\n",
")\n",
"\n",
"# Delete already existing documents from the table\n",
"db.delete(filter={})\n",
"db.add_documents(docs)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Self querying\n",
"\n",
"Now for the main act: here is how to construct a SelfQueryRetriever for HANA vectorstore:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain.chains.query_constructor.base import AttributeInfo\n",
"from langchain.retrievers.self_query.base import SelfQueryRetriever\n",
"from langchain_community.query_constructors.hanavector import HanaTranslator\n",
"from langchain_openai import ChatOpenAI\n",
"\n",
"llm = ChatOpenAI(model=\"gpt-3.5-turbo\")\n",
"\n",
"metadata_field_info = [\n",
" AttributeInfo(\n",
" name=\"name\",\n",
" description=\"The name of the person\",\n",
" type=\"string\",\n",
" ),\n",
" AttributeInfo(\n",
" name=\"is_active\",\n",
" description=\"Whether the person is active\",\n",
" type=\"boolean\",\n",
" ),\n",
" AttributeInfo(\n",
" name=\"id\",\n",
" description=\"The ID of the person\",\n",
" type=\"integer\",\n",
" ),\n",
" AttributeInfo(\n",
" name=\"height\",\n",
" description=\"The height of the person\",\n",
" type=\"float\",\n",
" ),\n",
"]\n",
"\n",
"document_content_description = \"A collection of persons\"\n",
"\n",
"hana_translator = HanaTranslator()\n",
"\n",
"retriever = SelfQueryRetriever.from_llm(\n",
" llm,\n",
" db,\n",
" document_content_description,\n",
" metadata_field_info,\n",
" structured_query_translator=hana_translator,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's use this retriever to prepare a (self) query for a person:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"query_prompt = \"Which person is not active?\"\n",
"\n",
"docs = retriever.invoke(input=query_prompt)\n",
"for doc in docs:\n",
" print(\"-\" * 80)\n",
" print(doc.page_content, \" \", doc.metadata)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also take a look at how the query is being constructed:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain.chains.query_constructor.base import (\n",
" StructuredQueryOutputParser,\n",
" get_query_constructor_prompt,\n",
")\n",
"\n",
"prompt = get_query_constructor_prompt(\n",
" document_content_description,\n",
" metadata_field_info,\n",
")\n",
"output_parser = StructuredQueryOutputParser.from_components()\n",
"query_constructor = prompt | llm | output_parser\n",
"\n",
"sq = query_constructor.invoke(input=query_prompt)\n",
"\n",
"print(\"Structured query: \", sq)\n",
"\n",
"print(\"Translated for hana vector store: \", hana_translator.visit_structured_query(sq))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@ -0,0 +1,57 @@
# HANA Translator/query constructor
from typing import Dict, Tuple, Union
from langchain_core.structured_query import (
Comparator,
Comparison,
Operation,
Operator,
StructuredQuery,
Visitor,
)
class HanaTranslator(Visitor):
"""
Translate internal query language elements to valid filters params for
HANA vectorstore.
"""
allowed_operators = [Operator.AND, Operator.OR]
"""Subset of allowed logical operators."""
allowed_comparators = [
Comparator.EQ,
Comparator.NE,
Comparator.GT,
Comparator.LT,
Comparator.GTE,
Comparator.LTE,
Comparator.IN,
Comparator.NIN,
# Comparator.CONTAIN,
Comparator.LIKE,
]
def _format_func(self, func: Union[Operator, Comparator]) -> str:
self._validate_func(func)
return f"${func.value}"
def visit_operation(self, operation: Operation) -> Dict:
args = [arg.accept(self) for arg in operation.arguments]
return {self._format_func(operation.operator): args}
def visit_comparison(self, comparison: Comparison) -> Dict:
return {
comparison.attribute: {
self._format_func(comparison.comparator): comparison.value
}
}
def visit_structured_query(
self, structured_query: StructuredQuery
) -> Tuple[str, dict]:
if structured_query.filter is None:
kwargs = {}
else:
kwargs = {"filter": structured_query.filter.accept(self)}
return structured_query.query, kwargs

View File

@ -191,7 +191,8 @@ class HanaDB(VectorStore):
if column_length is not None and column_length > 0:
if rows[0][1] != column_length:
raise AttributeError(
f"Column {column_name} has the wrong length: {rows[0][1]}"
f"Column {column_name} has the wrong length: {rows[0][1]} "
f"expected: {column_length}"
)
else:
raise AttributeError(f"Column {column_name} does not exist")
@ -529,10 +530,18 @@ class HanaDB(VectorStore):
if special_op in COMPARISONS_TO_SQL:
operator = COMPARISONS_TO_SQL[special_op]
if isinstance(special_val, bool):
query_tuple.append("true" if filter_value else "false")
query_tuple.append("true" if special_val else "false")
elif isinstance(special_val, float):
sql_param = "CAST(? as float)"
query_tuple.append(special_val)
elif (
isinstance(special_val, dict)
and "type" in special_val
and special_val["type"] == "date"
):
# Date type
sql_param = "CAST(? as DATE)"
query_tuple.append(special_val["date"])
else:
query_tuple.append(special_val)
# "$between"

View File

@ -0,0 +1,84 @@
from typing import Dict, Tuple
import pytest as pytest
from langchain_core.structured_query import (
Comparator,
Comparison,
Operation,
Operator,
StructuredQuery,
)
from langchain_community.query_constructors.hanavector import HanaTranslator
DEFAULT_TRANSLATOR = HanaTranslator()
def test_visit_comparison() -> None:
comp = Comparison(comparator=Comparator.LT, attribute="foo", value=1)
expected = {"foo": {"$lt": 1}}
actual = DEFAULT_TRANSLATOR.visit_comparison(comp)
assert expected == actual
def test_visit_operation() -> None:
op = Operation(
operator=Operator.AND,
arguments=[
Comparison(comparator=Comparator.LT, attribute="foo", value=2),
Comparison(comparator=Comparator.EQ, attribute="bar", value="baz"),
Comparison(comparator=Comparator.GT, attribute="abc", value=2.0),
],
)
expected = {
"$and": [{"foo": {"$lt": 2}}, {"bar": {"$eq": "baz"}}, {"abc": {"$gt": 2.0}}]
}
actual = DEFAULT_TRANSLATOR.visit_operation(op)
assert expected == actual
def test_visit_structured_query() -> None:
query = "What is the capital of France?"
structured_query = StructuredQuery(
query=query,
filter=None,
)
expected: Tuple[str, Dict] = (query, {})
actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query)
assert expected == actual
comp = Comparison(comparator=Comparator.LT, attribute="foo", value=1)
structured_query = StructuredQuery(
query=query,
filter=comp,
)
expected = (query, {"filter": {"foo": {"$lt": 1}}})
actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query)
assert expected == actual
op = Operation(
operator=Operator.AND,
arguments=[
Comparison(comparator=Comparator.LT, attribute="foo", value=2),
Comparison(comparator=Comparator.EQ, attribute="bar", value="baz"),
Comparison(comparator=Comparator.GT, attribute="abc", value=2.0),
],
)
structured_query = StructuredQuery(
query=query,
filter=op,
)
expected = (
query,
{
"filter": {
"$and": [
{"foo": {"$lt": 2}},
{"bar": {"$eq": "baz"}},
{"abc": {"$gt": 2.0}},
]
}
},
)
actual = DEFAULT_TRANSLATOR.visit_structured_query(structured_query)
assert expected == actual

View File

@ -177,6 +177,16 @@ def _get_builtin_translator(vectorstore: VectorStore) -> Visitor:
if isinstance(vectorstore, PGVector):
return NewPGVectorTranslator()
try:
# Added in langchain-community==0.2.11
from langchain_community.query_constructors.hanavector import HanaTranslator
from langchain_community.vectorstores import HanaDB
except ImportError:
pass
else:
if isinstance(vectorstore, HanaDB):
return HanaTranslator()
raise ValueError(
f"Self query retriever with Vector Store type {vectorstore.__class__}"
f" not supported."