mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-15 07:36:08 +00:00
Implementing the MMR algorithm for OLAP vector storage (#30033)
Thank you for contributing to LangChain! - **Implementing the MMR algorithm for OLAP vector storage**: - Support the Apache Doris and StarRocks OLAP databases. - Example: `vectorstore.as_retriever(search_type="mmr", search_kwargs={"k": 10})` - **Implementing the MMR algorithm for OLAP vector storage**: - **Apache Doris** - **StarRocks** - **Dependencies:** any dependencies required for this change - **Twitter handle:** if your PR gets announced, and you'd like a mention, we'll gladly shout you out! - **Add tests and docs**: - Example: `vectorstore.as_retriever(search_type="mmr", search_kwargs={"k": 10})` - [ ] **Lint and test**: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified. See contribution guidelines for more: https://python.langchain.com/docs/contributing/ Additional guidelines: - Make sure optional dependencies are imported within a function. - Please do not add dependencies to pyproject.toml files (even optional ones) unless they are required for unit tests. - Most PRs should not touch more than one package. - Changes should be backwards compatible. - If you are adding something to community, do not re-import it in langchain. If no one reviews your PR within a few days, please @-mention one of baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17. --------- Co-authored-by: fakzhao <fakzhao@cisco.com>
This commit is contained in:
parent
186cd7f1a1
commit
f07338d2bf
@ -4,16 +4,30 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
from hashlib import sha1
|
from hashlib import sha1
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
from langchain_core.vectorstores import VectorStore
|
from langchain_core.vectorstores import VectorStore
|
||||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
|
from langchain_community.vectorstores.utils import maximal_marginal_relevance
|
||||||
|
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
DEBUG = False
|
DEBUG = False
|
||||||
|
|
||||||
|
# A single row's metadata: string keys mapped to scalar values.
Metadata = Mapping[str, Union[str, int, float, bool]]


class QueryResult(TypedDict):
    """Column-oriented container for the rows returned by a vector query.

    Each field is a list holding one entry per row; the lists are built
    from the same rows in the same order.
    """

    # Row identifiers, one per row (a flat list, not a list of lists).
    ids: List[str]
    # Stored embedding of each row, as decoded by the caller.
    embeddings: List[Any]
    # Raw document text of each row; wrapped into Document objects later.
    documents: List[str]
    # Decoded metadata of each row, or None when not available.
    metadatas: Optional[List[Metadata]]
    # Distance/similarity value of each row, or None when not available.
    distances: Optional[List[float]]
|
||||||
|
|
||||||
|
|
||||||
class ApacheDorisSettings(BaseSettings):
|
class ApacheDorisSettings(BaseSettings):
|
||||||
"""Apache Doris client configuration.
|
"""Apache Doris client configuration.
|
||||||
@ -310,10 +324,13 @@ CREATE TABLE IF NOT EXISTS {self.config.database}.{self.config.table}(
|
|||||||
where_str = ""
|
where_str = ""
|
||||||
|
|
||||||
q_str = f"""
|
q_str = f"""
|
||||||
SELECT {self.config.column_map["document"]},
|
SELECT
|
||||||
{self.config.column_map["metadata"]},
|
id as id,
|
||||||
|
{self.config.column_map["document"]} as document,
|
||||||
|
{self.config.column_map["metadata"]} as metadata,
|
||||||
cosine_distance(array<float>[{q_emb_str}],
|
cosine_distance(array<float>[{q_emb_str}],
|
||||||
{self.config.column_map["embedding"]}) as dist
|
{self.config.column_map["embedding"]}) as dist,
|
||||||
|
{self.config.column_map["embedding"]} as embedding
|
||||||
FROM {self.config.database}.{self.config.table}
|
FROM {self.config.database}.{self.config.table}
|
||||||
{where_str}
|
{where_str}
|
||||||
ORDER BY dist {self.dist_order}
|
ORDER BY dist {self.dist_order}
|
||||||
@ -371,12 +388,13 @@ CREATE TABLE IF NOT EXISTS {self.config.database}.{self.config.table}(
|
|||||||
"""
|
"""
|
||||||
q_str = self._build_query_sql(embedding, k, where_str)
|
q_str = self._build_query_sql(embedding, k, where_str)
|
||||||
try:
|
try:
|
||||||
|
q_r = _get_named_result(self.connection, q_str)
|
||||||
return [
|
return [
|
||||||
Document(
|
Document(
|
||||||
page_content=r[self.config.column_map["document"]],
|
page_content=r[self.config.column_map["document"]],
|
||||||
metadata=json.loads(r[self.config.column_map["metadata"]]),
|
metadata=json.loads(r[self.config.column_map["metadata"]]),
|
||||||
)
|
)
|
||||||
for r in _get_named_result(self.connection, q_str)
|
for r in q_r
|
||||||
]
|
]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
|
logger.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
|
||||||
@ -430,6 +448,63 @@ CREATE TABLE IF NOT EXISTS {self.config.database}.{self.config.table}(
|
|||||||
def metadata_column(self) -> str:
|
def metadata_column(self) -> str:
|
||||||
return self.config.column_map["metadata"]
|
return self.config.column_map["metadata"]
|
||||||
|
|
||||||
|
def max_marginal_relevance_search_by_vector(
    self,
    embedding: list[float],
    k: int = 4,
    fetch_k: int = 20,
    lambda_mult: float = 0.5,
    **kwargs: Any,
) -> list[Document]:
    """Return documents chosen by maximal marginal relevance (MMR).

    Candidate rows are fetched for ``embedding`` and then re-ranked with
    ``maximal_marginal_relevance``.

    Args:
        embedding: Query embedding vector.
        k: Number of documents to select.
        fetch_k: Candidate count handed to ``_build_query_sql``.
        lambda_mult: Passed through to ``maximal_marginal_relevance``.
        **kwargs: An optional ``where_str`` entry is forwarded to
            ``_build_query_sql`` as the filter predicate; any other
            keyword is accepted and ignored.

    Returns:
        The selected documents, ordered as MMR picked them. An empty list
        when the query yields no rows.
    """
    q_str = self._build_query_sql(embedding, fetch_k, kwargs.get("where_str"))
    rows = _get_named_result(self.connection, q_str)
    if not rows:
        return []

    # The query labels its output columns "id", "document", "metadata",
    # "dist" and "embedding", so rows are read by those labels; going
    # through column_map would fail for a customised column mapping.
    results = QueryResult(
        ids=[row["id"] for row in rows],
        embeddings=[json.loads(row["embedding"]) for row in rows],
        documents=[row["document"] for row in rows],
        metadatas=[json.loads(row["metadata"]) for row in rows],
        distances=[row["dist"] for row in rows],
    )

    mmr_selected = maximal_marginal_relevance(
        np.array(embedding, dtype=np.float32),
        results["embeddings"],
        k=k,
        lambda_mult=lambda_mult,
    )

    candidates = _results_to_docs(results)

    # Index by the returned positions so the MMR ranking is kept; filtering
    # the candidates with "i in mmr_selected" would silently fall back to
    # the database's distance ordering.
    return [candidates[i] for i in mmr_selected]
|
||||||
|
|
||||||
|
def max_marginal_relevance_search(
    self,
    query: str,
    k: int = 5,
    fetch_k: int = 20,
    lambda_mult: float = 0.5,
    filter: Optional[Dict[str, str]] = None,
    where_document: Optional[Dict[str, str]] = None,
    **kwargs: Any,
) -> List[Document]:
    """Embed ``query`` and run an MMR search with the resulting vector.

    Args:
        query: Text to embed with ``self.embeddings``.
        k: Number of documents to select.
        fetch_k: Number of candidates to fetch before re-ranking.
        lambda_mult: MMR trade-off parameter, passed through unchanged.
        filter: Forwarded as a keyword argument to
            ``max_marginal_relevance_search_by_vector``.
        where_document: Forwarded as a keyword argument to
            ``max_marginal_relevance_search_by_vector``.
        **kwargs: Extra keyword arguments, forwarded as well so that
            caller-supplied options are not dropped on the way.

    Returns:
        Whatever ``max_marginal_relevance_search_by_vector`` returns.

    Raises:
        ValueError: If no embedding function is configured.
    """
    if self.embeddings is None:
        raise ValueError(
            "For MMR search, you must specify an embedding function on creation."
        )

    embedding = self.embeddings.embed_query(query)
    return self.max_marginal_relevance_search_by_vector(
        embedding,
        k,
        fetch_k,
        lambda_mult=lambda_mult,
        filter=filter,
        where_document=where_document,
        **kwargs,
    )
|
||||||
|
|
||||||
|
|
||||||
def _has_mul_sub_str(s: str, *args: Any) -> bool:
|
def _has_mul_sub_str(s: str, *args: Any) -> bool:
|
||||||
"""Check if a string has multiple substrings.
|
"""Check if a string has multiple substrings.
|
||||||
@ -480,3 +555,18 @@ def _get_named_result(connection: Any, query: str) -> List[dict[str, Any]]:
|
|||||||
_debug_output(result)
|
_debug_output(result)
|
||||||
cursor.close()
|
cursor.close()
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _results_to_docs(results: Any) -> List[Document]:
    """Return only the documents from ``results``, dropping the scores."""
    scored_docs = _results_to_docs_and_scores(results)
    return [pair[0] for pair in scored_docs]
|
||||||
|
|
||||||
|
|
||||||
|
def _results_to_docs_and_scores(results: Any) -> List[Tuple[Document, float]]:
    """Pair every document in ``results`` with its distance value.

    Reads the parallel ``"documents"``, ``"metadatas"`` and ``"distances"``
    lists of ``results`` and walks them in lockstep. A falsy metadata entry
    is replaced by an empty dict.
    """
    rows = zip(
        results["documents"],
        results["metadatas"],
        results["distances"],
    )
    docs_and_scores = []
    for text, meta, distance in rows:
        doc = Document(page_content=text, metadata=meta or {})
        docs_and_scores.append((doc, distance))
    return docs_and_scores
|
||||||
|
@ -4,12 +4,16 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
from hashlib import sha1
|
from hashlib import sha1
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
from langchain_core.vectorstores import VectorStore
|
from langchain_core.vectorstores import VectorStore
|
||||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
|
from langchain_community.vectorstores.utils import maximal_marginal_relevance
|
||||||
|
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
DEBUG = False
|
DEBUG = False
|
||||||
@ -66,6 +70,17 @@ def get_named_result(connection: Any, query: str) -> List[dict[str, Any]]:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
# A single row's metadata: string keys mapped to scalar values.
Metadata = Mapping[str, Union[str, int, float, bool]]


class QueryResult(TypedDict):
    """Column-oriented container for the rows returned by a vector query.

    Each field is a list holding one entry per row; the lists are built
    from the same rows in the same order.
    """

    # Row identifiers, one per row (a flat list, not a list of lists).
    ids: List[str]
    # Stored embedding of each row, as decoded by the caller.
    embeddings: List[Any]
    # Raw document text of each row; wrapped into Document objects later.
    documents: List[str]
    # Decoded metadata of each row, or None when not available.
    metadatas: Optional[List[Metadata]]
    # Distance/similarity value of each row, or None when not available.
    distances: Optional[List[float]]
|
||||||
|
|
||||||
|
|
||||||
class StarRocksSettings(BaseSettings):
|
class StarRocksSettings(BaseSettings):
|
||||||
"""StarRocks client configuration.
|
"""StarRocks client configuration.
|
||||||
|
|
||||||
@ -363,10 +378,13 @@ CREATE TABLE IF NOT EXISTS {self.config.database}.{self.config.table}(
|
|||||||
where_str = ""
|
where_str = ""
|
||||||
|
|
||||||
q_str = f"""
|
q_str = f"""
|
||||||
SELECT {self.config.column_map["document"]},
|
SELECT
|
||||||
{self.config.column_map["metadata"]},
|
id as id,
|
||||||
|
{self.config.column_map["document"]} as document,
|
||||||
|
{self.config.column_map["metadata"]} as metadata,
|
||||||
cosine_similarity_norm(array<float>[{q_emb_str}],
|
cosine_similarity_norm(array<float>[{q_emb_str}],
|
||||||
{self.config.column_map["embedding"]}) as dist
|
{self.config.column_map["embedding"]}) as dist,
|
||||||
|
{self.config.column_map["embedding"]} as embedding
|
||||||
FROM {self.config.database}.{self.config.table}
|
FROM {self.config.database}.{self.config.table}
|
||||||
{where_str}
|
{where_str}
|
||||||
ORDER BY dist {self.dist_order}
|
ORDER BY dist {self.dist_order}
|
||||||
@ -424,12 +442,13 @@ CREATE TABLE IF NOT EXISTS {self.config.database}.{self.config.table}(
|
|||||||
"""
|
"""
|
||||||
q_str = self._build_query_sql(embedding, k, where_str)
|
q_str = self._build_query_sql(embedding, k, where_str)
|
||||||
try:
|
try:
|
||||||
|
q_r = get_named_result(self.connection, q_str)
|
||||||
return [
|
return [
|
||||||
Document(
|
Document(
|
||||||
page_content=r[self.config.column_map["document"]],
|
page_content=r[self.config.column_map["document"]],
|
||||||
metadata=json.loads(r[self.config.column_map["metadata"]]),
|
metadata=json.loads(r[self.config.column_map["metadata"]]),
|
||||||
)
|
)
|
||||||
for r in get_named_result(self.connection, q_str)
|
for r in q_r
|
||||||
]
|
]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
|
logger.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
|
||||||
@ -484,3 +503,75 @@ CREATE TABLE IF NOT EXISTS {self.config.database}.{self.config.table}(
|
|||||||
@property
|
@property
|
||||||
def metadata_column(self) -> str:
|
def metadata_column(self) -> str:
|
||||||
return self.config.column_map["metadata"]
|
return self.config.column_map["metadata"]
|
||||||
|
|
||||||
|
def max_marginal_relevance_search_by_vector(
    self,
    embedding: list[float],
    k: int = 4,
    fetch_k: int = 20,
    lambda_mult: float = 0.5,
    **kwargs: Any,
) -> list[Document]:
    """Return documents chosen by maximal marginal relevance (MMR).

    Candidate rows are fetched for ``embedding`` and then re-ranked with
    ``maximal_marginal_relevance``.

    Args:
        embedding: Query embedding vector.
        k: Number of documents to select.
        fetch_k: Candidate count handed to ``_build_query_sql``.
        lambda_mult: Passed through to ``maximal_marginal_relevance``.
        **kwargs: An optional ``where_str`` entry is forwarded to
            ``_build_query_sql`` as the filter predicate; any other
            keyword is accepted and ignored.

    Returns:
        The selected documents, ordered as MMR picked them. An empty list
        when the query yields no rows.
    """
    q_str = self._build_query_sql(embedding, fetch_k, kwargs.get("where_str"))
    rows = get_named_result(self.connection, q_str)
    if not rows:
        return []

    # The query labels its output columns "id", "document", "metadata",
    # "dist" and "embedding", so rows are read by those labels; going
    # through column_map would fail for a customised column mapping.
    results = QueryResult(
        ids=[row["id"] for row in rows],
        embeddings=[json.loads(row["embedding"]) for row in rows],
        documents=[row["document"] for row in rows],
        metadatas=[json.loads(row["metadata"]) for row in rows],
        distances=[row["dist"] for row in rows],
    )

    mmr_selected = maximal_marginal_relevance(
        np.array(embedding, dtype=np.float32),
        results["embeddings"],
        k=k,
        lambda_mult=lambda_mult,
    )

    candidates = _results_to_docs(results)

    # Index by the returned positions so the MMR ranking is kept; filtering
    # the candidates with "i in mmr_selected" would silently fall back to
    # the database's ordering.
    return [candidates[i] for i in mmr_selected]
|
||||||
|
|
||||||
|
def max_marginal_relevance_search(
    self,
    query: str,
    k: int = 5,
    fetch_k: int = 20,
    lambda_mult: float = 0.5,
    filter: Optional[Dict[str, str]] = None,
    where_document: Optional[Dict[str, str]] = None,
    **kwargs: Any,
) -> List[Document]:
    """Embed ``query`` and run an MMR search with the resulting vector.

    Args:
        query: Text to embed with ``self.embeddings``.
        k: Number of documents to select.
        fetch_k: Number of candidates to fetch before re-ranking.
        lambda_mult: MMR trade-off parameter, passed through unchanged.
        filter: Forwarded as a keyword argument to
            ``max_marginal_relevance_search_by_vector``.
        where_document: Forwarded as a keyword argument to
            ``max_marginal_relevance_search_by_vector``.
        **kwargs: Extra keyword arguments, forwarded as well so that
            caller-supplied options are not dropped on the way.

    Returns:
        Whatever ``max_marginal_relevance_search_by_vector`` returns.

    Raises:
        ValueError: If no embedding function is configured.
    """
    if self.embeddings is None:
        raise ValueError(
            "For MMR search, you must specify an embedding function on creation."
        )

    embedding = self.embeddings.embed_query(query)
    return self.max_marginal_relevance_search_by_vector(
        embedding,
        k,
        fetch_k,
        lambda_mult=lambda_mult,
        filter=filter,
        where_document=where_document,
        **kwargs,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _results_to_docs(results: Any) -> List[Document]:
    """Return only the documents from ``results``, dropping the scores."""
    scored_docs = _results_to_docs_and_scores(results)
    return [pair[0] for pair in scored_docs]
|
||||||
|
|
||||||
|
|
||||||
|
def _results_to_docs_and_scores(results: Any) -> List[Tuple[Document, float]]:
    """Pair every document in ``results`` with its distance value.

    Reads the parallel ``"documents"``, ``"metadatas"`` and ``"distances"``
    lists of ``results`` and walks them in lockstep. A falsy metadata entry
    is replaced by an empty dict.
    """
    rows = zip(
        results["documents"],
        results["metadatas"],
        results["distances"],
    )
    docs_and_scores = []
    for text, meta, distance in rows:
        doc = Document(page_content=text, metadata=meta or {})
        docs_and_scores.append((doc, distance))
    return docs_and_scores
|
||||||
|
Loading…
Reference in New Issue
Block a user