Implementing the MMR algorithm for OLAP vector storage (#30033)

Thank you for contributing to LangChain!

-  **Implementing the MMR algorithm for OLAP vector storage**: 
  - Support Apache Doris and StarRocks OLAP database.
- Example: "vectorstore.as_retriever(search_type="mmr",
search_kwargs={"k": 10})"


- **Implementing the MMR algorithm for OLAP vector storage**: 
    - **Apache Doris**
    - **StarRocks**
    - **Dependencies:** any dependencies required for this change
- **Twitter handle:** if your PR gets announced, and you'd like a
mention, we'll gladly shout you out!


- **Add tests and docs**: 
- Example: "vectorstore.as_retriever(search_type="mmr",
search_kwargs={"k": 10})"


- [ ] **Lint and test**: Run `make format`, `make lint` and `make test`
from the root of the package(s) you've modified. See contribution
guidelines for more: https://python.langchain.com/docs/contributing/

Additional guidelines:
- Make sure optional dependencies are imported within a function.
- Please do not add dependencies to pyproject.toml files (even optional
ones) unless they are required for unit tests.
- Most PRs should not touch more than one package.
- Changes should be backwards compatible.
- If you are adding something to community, do not re-import it in
langchain.

If no one reviews your PR within a few days, please @-mention one of
baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17.

---------

Co-authored-by: fakzhao <fakzhao@cisco.com>
This commit is contained in:
Fakai Zhao 2025-02-28 21:50:22 +08:00 committed by GitHub
parent 186cd7f1a1
commit f07338d2bf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 191 additions and 10 deletions

View File

@ -4,16 +4,30 @@ import json
import logging
from hashlib import sha1
from threading import Thread
from typing import Any, Dict, Iterable, List, Optional, Tuple
from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Union
import numpy as np
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore
from pydantic_settings import BaseSettings, SettingsConfigDict
from typing_extensions import TypedDict
from langchain_community.vectorstores.utils import maximal_marginal_relevance
logger = logging.getLogger()
DEBUG = False
# Metadata values are limited to JSON-friendly scalar types.
Metadata = Mapping[str, Union[str, int, float, bool]]


class QueryResult(TypedDict):
    """Column-major container for one raw vector-search result set.

    Each field holds one column across all returned rows.
    """

    # Row ids returned by the store.
    ids: List[List[str]]
    # Embedding vectors for each hit, parsed from their JSON representation.
    embeddings: List[Any]
    # NOTE(review): in the MMR path this is populated with raw page-content
    # strings, not Document objects — annotation looks inaccurate; confirm.
    documents: List[Document]
    # Parsed metadata mappings, when present.
    metadatas: Optional[List[Metadata]]
    # Similarity/distance scores, when present.
    distances: Optional[List[float]]
class ApacheDorisSettings(BaseSettings):
"""Apache Doris client configuration.
@ -310,10 +324,13 @@ CREATE TABLE IF NOT EXISTS {self.config.database}.{self.config.table}(
where_str = ""
q_str = f"""
SELECT {self.config.column_map["document"]},
{self.config.column_map["metadata"]},
SELECT
id as id,
{self.config.column_map["document"]} as document,
{self.config.column_map["metadata"]} as metadata,
cosine_distance(array<float>[{q_emb_str}],
{self.config.column_map["embedding"]}) as dist
{self.config.column_map["embedding"]}) as dist,
{self.config.column_map["embedding"]} as embedding
FROM {self.config.database}.{self.config.table}
{where_str}
ORDER BY dist {self.dist_order}
@ -371,12 +388,13 @@ CREATE TABLE IF NOT EXISTS {self.config.database}.{self.config.table}(
"""
q_str = self._build_query_sql(embedding, k, where_str)
try:
q_r = _get_named_result(self.connection, q_str)
return [
Document(
page_content=r[self.config.column_map["document"]],
metadata=json.loads(r[self.config.column_map["metadata"]]),
)
for r in _get_named_result(self.connection, q_str)
for r in q_r
]
except Exception as e:
logger.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
@ -430,6 +448,63 @@ CREATE TABLE IF NOT EXISTS {self.config.database}.{self.config.table}(
def metadata_column(self) -> str:
return self.config.column_map["metadata"]
def max_marginal_relevance_search_by_vector(
    self,
    embedding: list[float],
    k: int = 4,
    fetch_k: int = 20,
    lambda_mult: float = 0.5,
    **kwargs: Any,
) -> list[Document]:
    """Re-rank the ``fetch_k`` nearest rows with maximal marginal relevance.

    Args:
        embedding: Query embedding to search with.
        k: Number of documents to return.
        fetch_k: Number of candidate rows pulled from the store before MMR.
        lambda_mult: 0 favors diversity, 1 favors pure similarity.

    Returns:
        The MMR-selected documents, in store-result order.
    """
    sql = self._build_query_sql(embedding, fetch_k, None)
    rows = _get_named_result(self.connection, sql)
    # Hoist column-name lookups out of the per-row comprehensions.
    doc_col = self.config.column_map["document"]
    meta_col = self.config.column_map["metadata"]
    emb_col = self.config.column_map["embedding"]
    result_set = QueryResult(
        ids=[row["id"] for row in rows],
        embeddings=[json.loads(row[emb_col]) for row in rows],
        documents=[row[doc_col] for row in rows],
        metadatas=[json.loads(row[meta_col]) for row in rows],
        distances=[row["dist"] for row in rows],
    )
    chosen = maximal_marginal_relevance(
        np.array(embedding, dtype=np.float32),
        result_set["embeddings"],
        k=k,
        lambda_mult=lambda_mult,
    )
    chosen_set = set(chosen)
    docs = _results_to_docs(result_set)
    return [doc for idx, doc in enumerate(docs) if idx in chosen_set]
def max_marginal_relevance_search(
    self,
    query: str,
    k: int = 5,
    fetch_k: int = 20,
    lambda_mult: float = 0.5,
    filter: Optional[Dict[str, str]] = None,
    where_document: Optional[Dict[str, str]] = None,
    **kwargs: Any,
) -> List[Document]:
    """Return docs selected by maximal marginal relevance for ``query``.

    MMR balances similarity to the query against diversity among the
    documents it selects.

    Args:
        query: Text to embed and search for.
        k: Number of documents to return. Defaults to 5.
        fetch_k: Number of candidates fetched before MMR re-ranking.
        lambda_mult: 0 favors diversity, 1 favors similarity. Defaults to 0.5.
        filter: Optional filter, forwarded to the by-vector search.
        where_document: Optional document filter, forwarded as well.

    Returns:
        The MMR-selected documents.

    Raises:
        ValueError: If no embedding function was supplied at creation time.
    """
    if self.embeddings is None:
        # Fixed message: was "oncreation" (missing space).
        raise ValueError(
            "For MMR search, you must specify an embedding function on creation."
        )
    embedding = self.embeddings.embed_query(query)
    return self.max_marginal_relevance_search_by_vector(
        embedding,
        k,
        fetch_k,
        lambda_mult=lambda_mult,
        filter=filter,
        where_document=where_document,
    )
def _has_mul_sub_str(s: str, *args: Any) -> bool:
"""Check if a string has multiple substrings.
@ -480,3 +555,18 @@ def _get_named_result(connection: Any, query: str) -> List[dict[str, Any]]:
_debug_output(result)
cursor.close()
return result
def _results_to_docs(results: Any) -> List[Document]:
    """Drop the scores from ``_results_to_docs_and_scores`` output."""
    return [pair[0] for pair in _results_to_docs_and_scores(results)]
def _results_to_docs_and_scores(results: Any) -> List[Tuple[Document, float]]:
    """Pair each row's content and metadata as a Document with its distance.

    Missing metadata (falsy) is replaced by an empty dict.
    """
    docs_with_scores: List[Tuple[Document, float]] = []
    for content, metadata, distance in zip(
        results["documents"],
        results["metadatas"],
        results["distances"],
    ):
        doc = Document(page_content=content, metadata=metadata or {})
        docs_with_scores.append((doc, distance))
    return docs_with_scores

View File

@ -4,12 +4,16 @@ import json
import logging
from hashlib import sha1
from threading import Thread
from typing import Any, Dict, Iterable, List, Optional, Tuple
from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Union
import numpy as np
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore
from pydantic_settings import BaseSettings, SettingsConfigDict
from typing_extensions import TypedDict
from langchain_community.vectorstores.utils import maximal_marginal_relevance
logger = logging.getLogger()
DEBUG = False
@ -66,6 +70,17 @@ def get_named_result(connection: Any, query: str) -> List[dict[str, Any]]:
return result
# Metadata values are limited to JSON-friendly scalar types.
Metadata = Mapping[str, Union[str, int, float, bool]]


class QueryResult(TypedDict):
    """Column-major container for one raw vector-search result set.

    Each field holds one column across all returned rows.
    """

    # Row ids returned by the store.
    ids: List[List[str]]
    # Embedding vectors for each hit, parsed from their JSON representation.
    embeddings: List[Any]
    # NOTE(review): in the MMR path this is populated with raw page-content
    # strings, not Document objects — annotation looks inaccurate; confirm.
    documents: List[Document]
    # Parsed metadata mappings, when present.
    metadatas: Optional[List[Metadata]]
    # Similarity/distance scores, when present.
    distances: Optional[List[float]]
class StarRocksSettings(BaseSettings):
"""StarRocks client configuration.
@ -363,10 +378,13 @@ CREATE TABLE IF NOT EXISTS {self.config.database}.{self.config.table}(
where_str = ""
q_str = f"""
SELECT {self.config.column_map["document"]},
{self.config.column_map["metadata"]},
SELECT
id as id,
{self.config.column_map["document"]} as document,
{self.config.column_map["metadata"]} as metadata,
cosine_similarity_norm(array<float>[{q_emb_str}],
{self.config.column_map["embedding"]}) as dist
{self.config.column_map["embedding"]}) as dist,
{self.config.column_map["embedding"]} as embedding
FROM {self.config.database}.{self.config.table}
{where_str}
ORDER BY dist {self.dist_order}
@ -424,12 +442,13 @@ CREATE TABLE IF NOT EXISTS {self.config.database}.{self.config.table}(
"""
q_str = self._build_query_sql(embedding, k, where_str)
try:
q_r = get_named_result(self.connection, q_str)
return [
Document(
page_content=r[self.config.column_map["document"]],
metadata=json.loads(r[self.config.column_map["metadata"]]),
)
for r in get_named_result(self.connection, q_str)
for r in q_r
]
except Exception as e:
logger.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
@ -484,3 +503,75 @@ CREATE TABLE IF NOT EXISTS {self.config.database}.{self.config.table}(
@property
def metadata_column(self) -> str:
return self.config.column_map["metadata"]
def max_marginal_relevance_search_by_vector(
    self,
    embedding: list[float],
    k: int = 4,
    fetch_k: int = 20,
    lambda_mult: float = 0.5,
    **kwargs: Any,
) -> list[Document]:
    """Re-rank the ``fetch_k`` nearest rows with maximal marginal relevance.

    Args:
        embedding: Query embedding to search with.
        k: Number of documents to return.
        fetch_k: Number of candidate rows pulled from the store before MMR.
        lambda_mult: 0 favors diversity, 1 favors pure similarity.

    Returns:
        The MMR-selected documents, in store-result order.
    """
    sql = self._build_query_sql(embedding, fetch_k, None)
    rows = get_named_result(self.connection, sql)
    # Hoist column-name lookups out of the per-row comprehensions.
    doc_col = self.config.column_map["document"]
    meta_col = self.config.column_map["metadata"]
    emb_col = self.config.column_map["embedding"]
    result_set = QueryResult(
        ids=[row["id"] for row in rows],
        embeddings=[json.loads(row[emb_col]) for row in rows],
        documents=[row[doc_col] for row in rows],
        metadatas=[json.loads(row[meta_col]) for row in rows],
        distances=[row["dist"] for row in rows],
    )
    chosen = maximal_marginal_relevance(
        np.array(embedding, dtype=np.float32),
        result_set["embeddings"],
        k=k,
        lambda_mult=lambda_mult,
    )
    chosen_set = set(chosen)
    docs = _results_to_docs(result_set)
    return [doc for idx, doc in enumerate(docs) if idx in chosen_set]
def max_marginal_relevance_search(
    self,
    query: str,
    k: int = 5,
    fetch_k: int = 20,
    lambda_mult: float = 0.5,
    filter: Optional[Dict[str, str]] = None,
    where_document: Optional[Dict[str, str]] = None,
    **kwargs: Any,
) -> List[Document]:
    """Return docs selected by maximal marginal relevance for ``query``.

    MMR balances similarity to the query against diversity among the
    documents it selects.

    Args:
        query: Text to embed and search for.
        k: Number of documents to return. Defaults to 5.
        fetch_k: Number of candidates fetched before MMR re-ranking.
        lambda_mult: 0 favors diversity, 1 favors similarity. Defaults to 0.5.
        filter: Optional filter, forwarded to the by-vector search.
        where_document: Optional document filter, forwarded as well.

    Returns:
        The MMR-selected documents.

    Raises:
        ValueError: If no embedding function was supplied at creation time.
    """
    if self.embeddings is None:
        # Fixed message: was "oncreation" (missing space).
        raise ValueError(
            "For MMR search, you must specify an embedding function on creation."
        )
    embedding = self.embeddings.embed_query(query)
    return self.max_marginal_relevance_search_by_vector(
        embedding,
        k,
        fetch_k,
        lambda_mult=lambda_mult,
        filter=filter,
        where_document=where_document,
    )
def _results_to_docs(results: Any) -> List[Document]:
    """Drop the scores from ``_results_to_docs_and_scores`` output."""
    return [pair[0] for pair in _results_to_docs_and_scores(results)]
def _results_to_docs_and_scores(results: Any) -> List[Tuple[Document, float]]:
    """Pair each row's content and metadata as a Document with its distance.

    Missing metadata (falsy) is replaced by an empty dict.
    """
    docs_with_scores: List[Tuple[Document, float]] = []
    for content, metadata, distance in zip(
        results["documents"],
        results["metadatas"],
        results["distances"],
    ):
        doc = Document(page_content=content, metadata=metadata or {})
        docs_with_scores.append((doc, distance))
    return docs_with_scores