Mirror of https://github.com/hwchase17/langchain.git, synced 2025-08-15 07:36:08 +00:00
Implementing the MMR algorithm for OLAP vector storage (#30033)
**Implementing the MMR algorithm for OLAP vector storage:**

- Adds maximal marginal relevance (MMR) search to the Apache Doris and StarRocks OLAP vector stores.
- Usage: `vectorstore.as_retriever(search_type="mmr", search_kwargs={"k": 10})`

Co-authored-by: fakzhao <fakzhao@cisco.com>
This commit is contained in:
parent 186cd7f1a1
commit f07338d2bf
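For context, a minimal usage sketch of the new search type. This assumes a reachable Apache Doris instance; `FakeEmbeddings` stands in for any real embedding model, and the constructor shapes follow the `langchain_community` conventions visible in the diff below.

```python
# Minimal sketch: MMR retrieval against the Apache Doris vector store.
# Connection details (host, port, credentials) come from ApacheDorisSettings,
# which, as a pydantic BaseSettings subclass, can read them from env vars.
from langchain_community.embeddings import FakeEmbeddings
from langchain_community.vectorstores.apache_doris import (
    ApacheDoris,
    ApacheDorisSettings,
)

store = ApacheDoris(embedding=FakeEmbeddings(size=1536), config=ApacheDorisSettings())

# search_type="mmr" routes queries through the new
# max_marginal_relevance_search() added in this commit.
retriever = store.as_retriever(search_type="mmr", search_kwargs={"k": 10})
docs = retriever.invoke("What is Apache Doris?")
```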
libs/community/langchain_community/vectorstores/apache_doris.py

@@ -4,16 +4,30 @@ import json
 import logging
 from hashlib import sha1
 from threading import Thread
-from typing import Any, Dict, Iterable, List, Optional, Tuple
+from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Union
 
+import numpy as np
 from langchain_core.documents import Document
 from langchain_core.embeddings import Embeddings
 from langchain_core.vectorstores import VectorStore
 from pydantic_settings import BaseSettings, SettingsConfigDict
+from typing_extensions import TypedDict
+
+from langchain_community.vectorstores.utils import maximal_marginal_relevance
 
 logger = logging.getLogger()
 DEBUG = False
 
+Metadata = Mapping[str, Union[str, int, float, bool]]
+
+
+class QueryResult(TypedDict):
+    ids: List[List[str]]
+    embeddings: List[Any]
+    documents: List[Document]
+    metadatas: Optional[List[Metadata]]
+    distances: Optional[List[float]]
+
+
 class ApacheDorisSettings(BaseSettings):
     """Apache Doris client configuration.
@@ -310,10 +324,13 @@ CREATE TABLE IF NOT EXISTS {self.config.database}.{self.config.table}(
             where_str = ""
 
         q_str = f"""
-            SELECT {self.config.column_map["document"]},
-                {self.config.column_map["metadata"]},
+            SELECT
+                id as id,
+                {self.config.column_map["document"]} as document,
+                {self.config.column_map["metadata"]} as metadata,
                 cosine_distance(array<float>[{q_emb_str}],
-                {self.config.column_map["embedding"]}) as dist
+                {self.config.column_map["embedding"]}) as dist,
+                {self.config.column_map["embedding"]} as embedding
             FROM {self.config.database}.{self.config.table}
             {where_str}
             ORDER BY dist {self.dist_order}
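The substantive change here is that the query now aliases every column and returns the raw `embedding` column alongside the distance, so the MMR step can re-rank candidates client-side without a second round trip. With hypothetical identifiers (the real names come from `column_map` and the configured database/table, and the ordering from `self.dist_order`), the rendered f-string looks roughly like:

```python
# Hypothetical rendering of q_str for a 2-dim query vector; identifiers,
# ordering direction, and any trailing LIMIT are assumptions, not the
# literal output of _build_query_sql.
q_str = """
    SELECT
        id as id,
        document as document,
        metadata as metadata,
        cosine_distance(array<float>[0.12, 0.34],
        embedding) as dist,
        embedding as embedding
    FROM langchain.langchain
    ORDER BY dist ASC
"""
```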
@@ -371,12 +388,13 @@ CREATE TABLE IF NOT EXISTS {self.config.database}.{self.config.table}(
         """
         q_str = self._build_query_sql(embedding, k, where_str)
         try:
+            q_r = _get_named_result(self.connection, q_str)
             return [
                 Document(
                     page_content=r[self.config.column_map["document"]],
                     metadata=json.loads(r[self.config.column_map["metadata"]]),
                 )
-                for r in _get_named_result(self.connection, q_str)
+                for r in q_r
             ]
         except Exception as e:
             logger.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
@@ -430,6 +448,63 @@ CREATE TABLE IF NOT EXISTS {self.config.database}.{self.config.table}(
     def metadata_column(self) -> str:
         return self.config.column_map["metadata"]
 
+    def max_marginal_relevance_search_by_vector(
+        self,
+        embedding: list[float],
+        k: int = 4,
+        fetch_k: int = 20,
+        lambda_mult: float = 0.5,
+        **kwargs: Any,
+    ) -> list[Document]:
+        q_str = self._build_query_sql(embedding, fetch_k, None)
+        q_r = _get_named_result(self.connection, q_str)
+        results = QueryResult(
+            ids=[r["id"] for r in q_r],
+            embeddings=[
+                json.loads(r[self.config.column_map["embedding"]]) for r in q_r
+            ],
+            documents=[r[self.config.column_map["document"]] for r in q_r],
+            metadatas=[json.loads(r[self.config.column_map["metadata"]]) for r in q_r],
+            distances=[r["dist"] for r in q_r],
+        )
+
+        mmr_selected = maximal_marginal_relevance(
+            np.array(embedding, dtype=np.float32),
+            results["embeddings"],
+            k=k,
+            lambda_mult=lambda_mult,
+        )
+
+        candidates = _results_to_docs(results)
+
+        selected_results = [r for i, r in enumerate(candidates) if i in mmr_selected]
+        return selected_results
+
+    def max_marginal_relevance_search(
+        self,
+        query: str,
+        k: int = 5,
+        fetch_k: int = 20,
+        lambda_mult: float = 0.5,
+        filter: Optional[Dict[str, str]] = None,
+        where_document: Optional[Dict[str, str]] = None,
+        **kwargs: Any,
+    ) -> List[Document]:
+        if self.embeddings is None:
+            raise ValueError(
+                "For MMR search, you must specify an embedding function on creation."
+            )
+
+        embedding = self.embeddings.embed_query(query)
+        return self.max_marginal_relevance_search_by_vector(
+            embedding,
+            k,
+            fetch_k,
+            lambda_mult=lambda_mult,
+            filter=filter,
+            where_document=where_document,
+        )
+
+
 def _has_mul_sub_str(s: str, *args: Any) -> bool:
     """Check if a string has multiple substrings.
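The re-ranking step above delegates to the existing `maximal_marginal_relevance` helper, which takes the query embedding plus the `fetch_k` candidate embeddings and returns the indices of the `k` picks that balance relevance against diversity. A standalone sketch with toy 2-dimensional vectors shows the trade-off:

```python
import numpy as np

from langchain_community.vectorstores.utils import maximal_marginal_relevance

query = np.array([1.0, 0.0], dtype=np.float32)
candidates = [
    [0.90, 0.10],  # very relevant to the query
    [0.89, 0.11],  # near-duplicate of the first candidate
    [0.00, 1.00],  # irrelevant but maximally diverse
]

# lambda_mult below 0.5 weights diversity more heavily; at 0.3 the
# near-duplicate loses to the diverse candidate.
picked = maximal_marginal_relevance(query, candidates, lambda_mult=0.3, k=2)
print(picked)  # [0, 2]
```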
@@ -480,3 +555,18 @@ def _get_named_result(connection: Any, query: str) -> List[dict[str, Any]]:
     _debug_output(result)
     cursor.close()
     return result
+
+
+def _results_to_docs(results: Any) -> List[Document]:
+    return [doc for doc, _ in _results_to_docs_and_scores(results)]
+
+
+def _results_to_docs_and_scores(results: Any) -> List[Tuple[Document, float]]:
+    return [
+        (Document(page_content=result[0], metadata=result[1] or {}), result[2])
+        for result in zip(
+            results["documents"],
+            results["metadatas"],
+            results["distances"],
+        )
+    ]
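For illustration, these helpers unpack a `QueryResult`-shaped mapping into `(Document, score)` pairs. A sketch with hypothetical values (note how `result[1] or {}` tolerates missing metadata), assuming the module-level helpers from the diff above are in scope:

```python
results = {
    "documents": ["doris is an OLAP database", "starrocks is an OLAP database"],
    "metadatas": [{"source": "a.md"}, None],
    "distances": [0.12, 0.34],
}

docs_and_scores = _results_to_docs_and_scores(results)
# [(Document(page_content="doris is an OLAP database",
#            metadata={"source": "a.md"}), 0.12),
#  (Document(page_content="starrocks is an OLAP database",
#            metadata={}), 0.34)]
```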
libs/community/langchain_community/vectorstores/starrocks.py

@@ -4,12 +4,16 @@ import json
 import logging
 from hashlib import sha1
 from threading import Thread
-from typing import Any, Dict, Iterable, List, Optional, Tuple
+from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Union
 
+import numpy as np
 from langchain_core.documents import Document
 from langchain_core.embeddings import Embeddings
 from langchain_core.vectorstores import VectorStore
 from pydantic_settings import BaseSettings, SettingsConfigDict
+from typing_extensions import TypedDict
+
+from langchain_community.vectorstores.utils import maximal_marginal_relevance
 
 logger = logging.getLogger()
 DEBUG = False
@@ -66,6 +70,17 @@ def get_named_result(connection: Any, query: str) -> List[dict[str, Any]]:
     return result
 
 
+Metadata = Mapping[str, Union[str, int, float, bool]]
+
+
+class QueryResult(TypedDict):
+    ids: List[List[str]]
+    embeddings: List[Any]
+    documents: List[Document]
+    metadatas: Optional[List[Metadata]]
+    distances: Optional[List[float]]
+
+
 class StarRocksSettings(BaseSettings):
     """StarRocks client configuration.
 
@@ -363,10 +378,13 @@ CREATE TABLE IF NOT EXISTS {self.config.database}.{self.config.table}(
             where_str = ""
 
         q_str = f"""
-            SELECT {self.config.column_map["document"]},
-                {self.config.column_map["metadata"]},
+            SELECT
+                id as id,
+                {self.config.column_map["document"]} as document,
+                {self.config.column_map["metadata"]} as metadata,
                 cosine_similarity_norm(array<float>[{q_emb_str}],
-                {self.config.column_map["embedding"]}) as dist
+                {self.config.column_map["embedding"]}) as dist,
+                {self.config.column_map["embedding"]} as embedding
             FROM {self.config.database}.{self.config.table}
             {where_str}
             ORDER BY dist {self.dist_order}
@@ -424,12 +442,13 @@ CREATE TABLE IF NOT EXISTS {self.config.database}.{self.config.table}(
         """
         q_str = self._build_query_sql(embedding, k, where_str)
         try:
+            q_r = get_named_result(self.connection, q_str)
             return [
                 Document(
                     page_content=r[self.config.column_map["document"]],
                     metadata=json.loads(r[self.config.column_map["metadata"]]),
                 )
-                for r in get_named_result(self.connection, q_str)
+                for r in q_r
             ]
         except Exception as e:
             logger.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
@@ -484,3 +503,75 @@ CREATE TABLE IF NOT EXISTS {self.config.database}.{self.config.table}(
     @property
     def metadata_column(self) -> str:
         return self.config.column_map["metadata"]
+
+    def max_marginal_relevance_search_by_vector(
+        self,
+        embedding: list[float],
+        k: int = 4,
+        fetch_k: int = 20,
+        lambda_mult: float = 0.5,
+        **kwargs: Any,
+    ) -> list[Document]:
+        q_str = self._build_query_sql(embedding, fetch_k, None)
+        q_r = get_named_result(self.connection, q_str)
+        results = QueryResult(
+            ids=[r["id"] for r in q_r],
+            embeddings=[
+                json.loads(r[self.config.column_map["embedding"]]) for r in q_r
+            ],
+            documents=[r[self.config.column_map["document"]] for r in q_r],
+            metadatas=[json.loads(r[self.config.column_map["metadata"]]) for r in q_r],
+            distances=[r["dist"] for r in q_r],
+        )
+
+        mmr_selected = maximal_marginal_relevance(
+            np.array(embedding, dtype=np.float32),
+            results["embeddings"],
+            k=k,
+            lambda_mult=lambda_mult,
+        )
+
+        candidates = _results_to_docs(results)
+
+        selected_results = [r for i, r in enumerate(candidates) if i in mmr_selected]
+        return selected_results
+
+    def max_marginal_relevance_search(
+        self,
+        query: str,
+        k: int = 5,
+        fetch_k: int = 20,
+        lambda_mult: float = 0.5,
+        filter: Optional[Dict[str, str]] = None,
+        where_document: Optional[Dict[str, str]] = None,
+        **kwargs: Any,
+    ) -> List[Document]:
+        if self.embeddings is None:
+            raise ValueError(
+                "For MMR search, you must specify an embedding function on creation."
+            )
+
+        embedding = self.embeddings.embed_query(query)
+        return self.max_marginal_relevance_search_by_vector(
+            embedding,
+            k,
+            fetch_k,
+            lambda_mult=lambda_mult,
+            filter=filter,
+            where_document=where_document,
+        )
+
+
+def _results_to_docs(results: Any) -> List[Document]:
+    return [doc for doc, _ in _results_to_docs_and_scores(results)]
+
+
+def _results_to_docs_and_scores(results: Any) -> List[Tuple[Document, float]]:
+    return [
+        (Document(page_content=result[0], metadata=result[1] or {}), result[2])
+        for result in zip(
+            results["documents"],
+            results["metadatas"],
+            results["distances"],
+        )
+    ]
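The StarRocks implementation mirrors the Doris one line for line; only the class names and the distance function (`cosine_similarity_norm` instead of `cosine_distance`) differ. A minimal direct-call sketch, under the same assumptions as the Doris example above:

```python
from langchain_community.embeddings import FakeEmbeddings
from langchain_community.vectorstores.starrocks import StarRocks, StarRocksSettings

store = StarRocks(embedding=FakeEmbeddings(size=1536), config=StarRocksSettings())

# Pull fetch_k=20 candidates via the SQL query, then keep the k=5 most
# relevant-yet-diverse ones after MMR re-ranking.
docs = store.max_marginal_relevance_search("What is StarRocks?", k=5, fetch_k=20)
```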