From f07338d2bfb71e1eba057c30b65b3ef602702436 Mon Sep 17 00:00:00 2001 From: Fakai Zhao Date: Fri, 28 Feb 2025 21:50:22 +0800 Subject: [PATCH] Implementing the MMR algorithm for OLAP vector storage (#30033) Thank you for contributing to LangChain! - **Implementing the MMR algorithm for OLAP vector storage**: - Support Apache Doris and StarRocks OLAP databases. - Example: "vectorstore.as_retriever(search_type="mmr", search_kwargs={"k": 10})" - **Implementing the MMR algorithm for OLAP vector storage**: - **Apache Doris** - **StarRocks** - **Dependencies:** any dependencies required for this change - **Twitter handle:** if your PR gets announced, and you'd like a mention, we'll gladly shout you out! - **Add tests and docs**: - Example: "vectorstore.as_retriever(search_type="mmr", search_kwargs={"k": 10})" - [ ] **Lint and test**: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified. See contribution guidelines for more: https://python.langchain.com/docs/contributing/ Additional guidelines: - Make sure optional dependencies are imported within a function. - Please do not add dependencies to pyproject.toml files (even optional ones) unless they are required for unit tests. - Most PRs should not touch more than one package. - Changes should be backwards compatible. - If you are adding something to community, do not re-import it in langchain. If no one reviews your PR within a few days, please @-mention one of baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17. 
--------- Co-authored-by: fakzhao --- .../vectorstores/apache_doris.py | 100 ++++++++++++++++- .../vectorstores/starrocks.py | 101 +++++++++++++++++- 2 files changed, 191 insertions(+), 10 deletions(-) diff --git a/libs/community/langchain_community/vectorstores/apache_doris.py b/libs/community/langchain_community/vectorstores/apache_doris.py index 56ee6c0f64f..0e88fba5bdf 100644 --- a/libs/community/langchain_community/vectorstores/apache_doris.py +++ b/libs/community/langchain_community/vectorstores/apache_doris.py @@ -4,16 +4,30 @@ import json import logging from hashlib import sha1 from threading import Thread -from typing import Any, Dict, Iterable, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Union +import numpy as np from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from langchain_core.vectorstores import VectorStore from pydantic_settings import BaseSettings, SettingsConfigDict +from typing_extensions import TypedDict + +from langchain_community.vectorstores.utils import maximal_marginal_relevance logger = logging.getLogger() DEBUG = False +Metadata = Mapping[str, Union[str, int, float, bool]] + + +class QueryResult(TypedDict): + ids: List[List[str]] + embeddings: List[Any] + documents: List[Document] + metadatas: Optional[List[Metadata]] + distances: Optional[List[float]] + class ApacheDorisSettings(BaseSettings): """Apache Doris client configuration. 
@@ -310,10 +324,13 @@ CREATE TABLE IF NOT EXISTS {self.config.database}.{self.config.table}( where_str = "" q_str = f""" - SELECT {self.config.column_map["document"]}, - {self.config.column_map["metadata"]}, + SELECT + id as id, + {self.config.column_map["document"]} as document, + {self.config.column_map["metadata"]} as metadata, cosine_distance(array[{q_emb_str}], - {self.config.column_map["embedding"]}) as dist + {self.config.column_map["embedding"]}) as dist, + {self.config.column_map["embedding"]} as embedding FROM {self.config.database}.{self.config.table} {where_str} ORDER BY dist {self.dist_order} @@ -371,12 +388,13 @@ CREATE TABLE IF NOT EXISTS {self.config.database}.{self.config.table}( """ q_str = self._build_query_sql(embedding, k, where_str) try: + q_r = _get_named_result(self.connection, q_str) return [ Document( page_content=r[self.config.column_map["document"]], metadata=json.loads(r[self.config.column_map["metadata"]]), ) - for r in _get_named_result(self.connection, q_str) + for r in q_r ] except Exception as e: logger.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m") @@ -430,6 +448,63 @@ CREATE TABLE IF NOT EXISTS {self.config.database}.{self.config.table}( def metadata_column(self) -> str: return self.config.column_map["metadata"] + def max_marginal_relevance_search_by_vector( + self, + embedding: list[float], + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + **kwargs: Any, + ) -> list[Document]: + q_str = self._build_query_sql(embedding, fetch_k, None) + q_r = _get_named_result(self.connection, q_str) + results = QueryResult( + ids=[r["id"] for r in q_r], + embeddings=[ + json.loads(r[self.config.column_map["embedding"]]) for r in q_r + ], + documents=[r[self.config.column_map["document"]] for r in q_r], + metadatas=[json.loads(r[self.config.column_map["metadata"]]) for r in q_r], + distances=[r["dist"] for r in q_r], + ) + + mmr_selected = maximal_marginal_relevance( + np.array(embedding, dtype=np.float32), + 
results["embeddings"], + k=k, + lambda_mult=lambda_mult, + ) + + candidates = _results_to_docs(results) + + selected_results = [r for i, r in enumerate(candidates) if i in mmr_selected] + return selected_results + + def max_marginal_relevance_search( + self, + query: str, + k: int = 5, + fetch_k: int = 20, + lambda_mult: float = 0.5, + filter: Optional[Dict[str, str]] = None, + where_document: Optional[Dict[str, str]] = None, + **kwargs: Any, + ) -> List[Document]: + if self.embeddings is None: + raise ValueError( + "For MMR search, you must specify an embedding function on creation." + ) + + embedding = self.embeddings.embed_query(query) + return self.max_marginal_relevance_search_by_vector( + embedding, + k, + fetch_k, + lambda_mult=lambda_mult, + filter=filter, + where_document=where_document, + ) + def _has_mul_sub_str(s: str, *args: Any) -> bool: """Check if a string has multiple substrings. @@ -480,3 +555,18 @@ def _get_named_result(connection: Any, query: str) -> List[dict[str, Any]]: _debug_output(result) cursor.close() return result + + +def _results_to_docs(results: Any) -> List[Document]: return [doc for doc, _ in _results_to_docs_and_scores(results)] + + +def _results_to_docs_and_scores(results: Any) -> List[Tuple[Document, float]]: return [ (Document(page_content=result[0], metadata=result[1] or {}), result[2]) for result in zip( results["documents"], results["metadatas"], results["distances"], ) ] diff --git a/libs/community/langchain_community/vectorstores/starrocks.py b/libs/community/langchain_community/vectorstores/starrocks.py index 9298f12a78f..d3ce2dcc9b5 100644 --- a/libs/community/langchain_community/vectorstores/starrocks.py +++ b/libs/community/langchain_community/vectorstores/starrocks.py @@ -4,12 +4,16 @@ import json import logging from hashlib import sha1 from threading import Thread -from typing import Any, Dict, Iterable, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, 
Union +import numpy as np from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from langchain_core.vectorstores import VectorStore from pydantic_settings import BaseSettings, SettingsConfigDict +from typing_extensions import TypedDict + +from langchain_community.vectorstores.utils import maximal_marginal_relevance logger = logging.getLogger() DEBUG = False @@ -66,6 +70,17 @@ def get_named_result(connection: Any, query: str) -> List[dict[str, Any]]: return result +Metadata = Mapping[str, Union[str, int, float, bool]] + + +class QueryResult(TypedDict): + ids: List[List[str]] + embeddings: List[Any] + documents: List[Document] + metadatas: Optional[List[Metadata]] + distances: Optional[List[float]] + + class StarRocksSettings(BaseSettings): """StarRocks client configuration. @@ -363,10 +378,13 @@ CREATE TABLE IF NOT EXISTS {self.config.database}.{self.config.table}( where_str = "" q_str = f""" - SELECT {self.config.column_map["document"]}, - {self.config.column_map["metadata"]}, + SELECT + id as id, + {self.config.column_map["document"]} as document, + {self.config.column_map["metadata"]} as metadata, cosine_similarity_norm(array[{q_emb_str}], - {self.config.column_map["embedding"]}) as dist + {self.config.column_map["embedding"]}) as dist, + {self.config.column_map["embedding"]} as embedding FROM {self.config.database}.{self.config.table} {where_str} ORDER BY dist {self.dist_order} @@ -424,12 +442,13 @@ CREATE TABLE IF NOT EXISTS {self.config.database}.{self.config.table}( """ q_str = self._build_query_sql(embedding, k, where_str) try: + q_r = get_named_result(self.connection, q_str) return [ Document( page_content=r[self.config.column_map["document"]], metadata=json.loads(r[self.config.column_map["metadata"]]), ) - for r in get_named_result(self.connection, q_str) + for r in q_r ] except Exception as e: logger.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m") @@ -484,3 +503,75 @@ CREATE TABLE IF NOT EXISTS 
{self.config.database}.{self.config.table}( @property def metadata_column(self) -> str: return self.config.column_map["metadata"] + + def max_marginal_relevance_search_by_vector( + self, + embedding: list[float], + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + **kwargs: Any, + ) -> list[Document]: + q_str = self._build_query_sql(embedding, fetch_k, None) + q_r = get_named_result(self.connection, q_str) + results = QueryResult( + ids=[r["id"] for r in q_r], + embeddings=[ + json.loads(r[self.config.column_map["embedding"]]) for r in q_r + ], + documents=[r[self.config.column_map["document"]] for r in q_r], + metadatas=[json.loads(r[self.config.column_map["metadata"]]) for r in q_r], + distances=[r["dist"] for r in q_r], + ) + + mmr_selected = maximal_marginal_relevance( + np.array(embedding, dtype=np.float32), + results["embeddings"], + k=k, + lambda_mult=lambda_mult, + ) + + candidates = _results_to_docs(results) + + selected_results = [r for i, r in enumerate(candidates) if i in mmr_selected] + return selected_results + + def max_marginal_relevance_search( + self, + query: str, + k: int = 5, + fetch_k: int = 20, + lambda_mult: float = 0.5, + filter: Optional[Dict[str, str]] = None, + where_document: Optional[Dict[str, str]] = None, + **kwargs: Any, + ) -> List[Document]: + if self.embeddings is None: + raise ValueError( + "For MMR search, you must specify an embedding function on creation." 
+ ) + + embedding = self.embeddings.embed_query(query) + return self.max_marginal_relevance_search_by_vector( + embedding, + k, + fetch_k, + lambda_mult=lambda_mult, + filter=filter, + where_document=where_document, + ) + + +def _results_to_docs(results: Any) -> List[Document]: + return [doc for doc, _ in _results_to_docs_and_scores(results)] + + +def _results_to_docs_and_scores(results: Any) -> List[Tuple[Document, float]]: + return [ + (Document(page_content=result[0], metadata=result[1] or {}), result[2]) + for result in zip( + results["documents"], + results["metadatas"], + results["distances"], + ) + ]