# Metadata attached to a stored document: string keys mapped to scalar values.
Metadata = Mapping[str, Union[str, int, float, bool]]


class QueryResult(TypedDict):
    """Column-oriented view of the rows returned by one vector query.

    The MMR search builds every field from the same row list, so position
    ``i`` of each list describes the same result row.
    """

    # Value of the ``id`` column of each row (presumably strings — confirm
    # against the table schema).
    ids: List[str]
    # Decoded embedding vector stored with each row.
    embeddings: List[Any]
    # Raw page content of each row. These are plain strings that are later
    # wrapped into ``Document`` objects, not ``Document`` instances.
    documents: List[str]
    # Decoded metadata of each row, or ``None`` when unavailable.
    metadatas: Optional[List[Metadata]]
    # Value of the query's ``dist`` column for each row, or ``None``.
    distances: Optional[List[float]]
@@ -310,10 +324,13 @@ CREATE TABLE IF NOT EXISTS {self.config.database}.{self.config.table}( where_str = "" q_str = f""" - SELECT {self.config.column_map["document"]}, - {self.config.column_map["metadata"]}, + SELECT + id as id, + {self.config.column_map["document"]} as document, + {self.config.column_map["metadata"]} as metadata, cosine_distance(array[{q_emb_str}], - {self.config.column_map["embedding"]}) as dist + {self.config.column_map["embedding"]}) as dist, + {self.config.column_map["embedding"]} as embedding FROM {self.config.database}.{self.config.table} {where_str} ORDER BY dist {self.dist_order} @@ -371,12 +388,13 @@ CREATE TABLE IF NOT EXISTS {self.config.database}.{self.config.table}( """ q_str = self._build_query_sql(embedding, k, where_str) try: + q_r = _get_named_result(self.connection, q_str) return [ Document( page_content=r[self.config.column_map["document"]], metadata=json.loads(r[self.config.column_map["metadata"]]), ) - for r in _get_named_result(self.connection, q_str) + for r in q_r ] except Exception as e: logger.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m") @@ -430,6 +448,63 @@ CREATE TABLE IF NOT EXISTS {self.config.database}.{self.config.table}( def metadata_column(self) -> str: return self.config.column_map["metadata"] + def max_marginal_relevance_search_by_vector( + self, + embedding: list[float], + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + **kwargs: Any, + ) -> list[Document]: + q_str = self._build_query_sql(embedding, fetch_k, None) + q_r = _get_named_result(self.connection, q_str) + results = QueryResult( + ids=[r["id"] for r in q_r], + embeddings=[ + json.loads(r[self.config.column_map["embedding"]]) for r in q_r + ], + documents=[r[self.config.column_map["document"]] for r in q_r], + metadatas=[json.loads(r[self.config.column_map["metadata"]]) for r in q_r], + distances=[r["dist"] for r in q_r], + ) + + mmr_selected = maximal_marginal_relevance( + np.array(embedding, dtype=np.float32), + 
results["embeddings"], + k=k, + lambda_mult=lambda_mult, + ) + + candidates = _results_to_docs(results) + + selected_results = [r for i, r in enumerate(candidates) if i in mmr_selected] + return selected_results + + def max_marginal_relevance_search( + self, + query: str, + k: int = 5, + fetch_k: int = 20, + lambda_mult: float = 0.5, + filter: Optional[Dict[str, str]] = None, + where_document: Optional[Dict[str, str]] = None, + **kwargs: Any, + ) -> List[Document]: + if self.embeddings is None: + raise ValueError( + "For MMR search, you must specify an embedding function oncreation." + ) + + embedding = self.embeddings.embed_query(query) + return self.max_marginal_relevance_search_by_vector( + embedding, + k, + fetch_k, + lambda_mult=lambda_mult, + filter=filter, + where_document=where_document, + ) + def _has_mul_sub_str(s: str, *args: Any) -> bool: """Check if a string has multiple substrings. @@ -480,3 +555,18 @@ def _get_named_result(connection: Any, query: str) -> List[dict[str, Any]]: _debug_output(result) cursor.close() return result + + +def _results_to_docs(results: Any) -> List[Document]: + return [doc for doc, _ in _results_to_docs_and_scores(results)] + + +def _results_to_docs_and_scores(results: Any) -> List[Tuple[Document, float]]: + return [ + (Document(page_content=result[0], metadata=result[1] or {}), result[2]) + for result in zip( + results["documents"], + results["metadatas"], + results["distances"], + ) + ] diff --git a/libs/community/langchain_community/vectorstores/starrocks.py b/libs/community/langchain_community/vectorstores/starrocks.py index 9298f12a78f..d3ce2dcc9b5 100644 --- a/libs/community/langchain_community/vectorstores/starrocks.py +++ b/libs/community/langchain_community/vectorstores/starrocks.py @@ -4,12 +4,16 @@ import json import logging from hashlib import sha1 from threading import Thread -from typing import Any, Dict, Iterable, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, 
# Metadata attached to a stored document: string keys mapped to scalar values.
Metadata = Mapping[str, Union[str, int, float, bool]]


class QueryResult(TypedDict):
    """Column-oriented view of the rows returned by one vector query.

    The MMR search builds every field from the same row list, so position
    ``i`` of each list describes the same result row.
    """

    # Value of the ``id`` column of each row (presumably strings — confirm
    # against the table schema).
    ids: List[str]
    # Decoded embedding vector stored with each row.
    embeddings: List[Any]
    # Raw page content of each row. These are plain strings that are later
    # wrapped into ``Document`` objects, not ``Document`` instances.
    documents: List[str]
    # Decoded metadata of each row, or ``None`` when unavailable.
    metadatas: Optional[List[Metadata]]
    # Value of the query's ``dist`` column for each row, or ``None``.
    distances: Optional[List[float]]
{self.config.database}.{self.config.table}( @property def metadata_column(self) -> str: return self.config.column_map["metadata"] + + def max_marginal_relevance_search_by_vector( + self, + embedding: list[float], + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + **kwargs: Any, + ) -> list[Document]: + q_str = self._build_query_sql(embedding, fetch_k, None) + q_r = get_named_result(self.connection, q_str) + results = QueryResult( + ids=[r["id"] for r in q_r], + embeddings=[ + json.loads(r[self.config.column_map["embedding"]]) for r in q_r + ], + documents=[r[self.config.column_map["document"]] for r in q_r], + metadatas=[json.loads(r[self.config.column_map["metadata"]]) for r in q_r], + distances=[r["dist"] for r in q_r], + ) + + mmr_selected = maximal_marginal_relevance( + np.array(embedding, dtype=np.float32), + results["embeddings"], + k=k, + lambda_mult=lambda_mult, + ) + + candidates = _results_to_docs(results) + + selected_results = [r for i, r in enumerate(candidates) if i in mmr_selected] + return selected_results + + def max_marginal_relevance_search( + self, + query: str, + k: int = 5, + fetch_k: int = 20, + lambda_mult: float = 0.5, + filter: Optional[Dict[str, str]] = None, + where_document: Optional[Dict[str, str]] = None, + **kwargs: Any, + ) -> List[Document]: + if self.embeddings is None: + raise ValueError( + "For MMR search, you must specify an embedding function oncreation." 
+ ) + + embedding = self.embeddings.embed_query(query) + return self.max_marginal_relevance_search_by_vector( + embedding, + k, + fetch_k, + lambda_mult=lambda_mult, + filter=filter, + where_document=where_document, + ) + + +def _results_to_docs(results: Any) -> List[Document]: + return [doc for doc, _ in _results_to_docs_and_scores(results)] + + +def _results_to_docs_and_scores(results: Any) -> List[Tuple[Document, float]]: + return [ + (Document(page_content=result[0], metadata=result[1] or {}), result[2]) + for result in zip( + results["documents"], + results["metadatas"], + results["distances"], + ) + ]