mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-13 14:50:00 +00:00
community[patch]: Matching engine, return doc id (#14930)
This commit is contained in:
parent
8a3360edf6
commit
345acb26ac
@ -22,7 +22,7 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
from langchain_community.embeddings import TensorflowHubEmbeddings
|
from langchain_community.embeddings import TensorflowHubEmbeddings
|
||||||
|
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class MatchingEngine(VectorStore):
|
class MatchingEngine(VectorStore):
|
||||||
@ -49,6 +49,8 @@ class MatchingEngine(VectorStore):
|
|||||||
gcs_client: storage.Client,
|
gcs_client: storage.Client,
|
||||||
gcs_bucket_name: str,
|
gcs_bucket_name: str,
|
||||||
credentials: Optional[Credentials] = None,
|
credentials: Optional[Credentials] = None,
|
||||||
|
*,
|
||||||
|
document_id_key: Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""Google Vertex AI Vector Search (previously Matching Engine)
|
"""Google Vertex AI Vector Search (previously Matching Engine)
|
||||||
implementation of the vector store.
|
implementation of the vector store.
|
||||||
@ -78,6 +80,9 @@ class MatchingEngine(VectorStore):
|
|||||||
gcs_client: The GCS client.
|
gcs_client: The GCS client.
|
||||||
gcs_bucket_name: The GCS bucket name.
|
gcs_bucket_name: The GCS bucket name.
|
||||||
credentials (Optional): Created GCP credentials.
|
credentials (Optional): Created GCP credentials.
|
||||||
|
document_id_key (Optional): Key for storing document ID in document
|
||||||
|
metadata. If None, document ID will not be returned in document
|
||||||
|
metadata.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._validate_google_libraries_installation()
|
self._validate_google_libraries_installation()
|
||||||
@ -89,6 +94,7 @@ class MatchingEngine(VectorStore):
|
|||||||
self.gcs_client = gcs_client
|
self.gcs_client = gcs_client
|
||||||
self.credentials = credentials
|
self.credentials = credentials
|
||||||
self.gcs_bucket_name = gcs_bucket_name
|
self.gcs_bucket_name = gcs_bucket_name
|
||||||
|
self.document_id_key = document_id_key
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def embeddings(self) -> Embeddings:
|
def embeddings(self) -> Embeddings:
|
||||||
@ -229,6 +235,7 @@ class MatchingEngine(VectorStore):
|
|||||||
List[Tuple[Document, float]]: List of documents most similar to
|
List[Tuple[Document, float]]: List of documents most similar to
|
||||||
the query text and cosine distance in float for each.
|
the query text and cosine distance in float for each.
|
||||||
Lower score represents more similarity.
|
Lower score represents more similarity.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
filter = filter or []
|
filter = filter or []
|
||||||
|
|
||||||
@ -255,19 +262,27 @@ class MatchingEngine(VectorStore):
|
|||||||
if len(response) == 0:
|
if len(response) == 0:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
results = []
|
docs: List[Tuple[Document, float]] = []
|
||||||
|
|
||||||
# I'm only getting the first one because queries receives an array
|
# I'm only getting the first one because queries receives an array
|
||||||
# and the similarity_search method only receives one query. This
|
# and the similarity_search method only receives one query. This
|
||||||
# means that the match method will always return an array with only
|
# means that the match method will always return an array with only
|
||||||
# one element.
|
# one element.
|
||||||
for doc in response[0]:
|
for result in response[0]:
|
||||||
page_content = self._download_from_gcs(f"documents/{doc.id}")
|
page_content = self._download_from_gcs(f"documents/{result.id}")
|
||||||
results.append((Document(page_content=page_content), doc.distance))
|
# TODO: return all metadata.
|
||||||
|
metadata = {}
|
||||||
|
if self.document_id_key is not None:
|
||||||
|
metadata[self.document_id_key] = result.id
|
||||||
|
document = Document(
|
||||||
|
page_content=page_content,
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
docs.append((document, result.distance))
|
||||||
|
|
||||||
logger.debug("Downloaded documents for query.")
|
logger.debug("Downloaded documents for query.")
|
||||||
|
|
||||||
return results
|
return docs
|
||||||
|
|
||||||
def similarity_search(
|
def similarity_search(
|
||||||
self,
|
self,
|
||||||
@ -382,6 +397,7 @@ class MatchingEngine(VectorStore):
|
|||||||
endpoint_id: str,
|
endpoint_id: str,
|
||||||
credentials_path: Optional[str] = None,
|
credentials_path: Optional[str] = None,
|
||||||
embedding: Optional[Embeddings] = None,
|
embedding: Optional[Embeddings] = None,
|
||||||
|
**kwargs: Any,
|
||||||
) -> "MatchingEngine":
|
) -> "MatchingEngine":
|
||||||
"""Takes the object creation out of the constructor.
|
"""Takes the object creation out of the constructor.
|
||||||
|
|
||||||
@ -397,6 +413,7 @@ class MatchingEngine(VectorStore):
|
|||||||
the local file system.
|
the local file system.
|
||||||
embedding: The :class:`Embeddings` that will be used for
|
embedding: The :class:`Embeddings` that will be used for
|
||||||
embedding the texts.
|
embedding the texts.
|
||||||
|
kwargs: Additional keyword arguments to pass to MatchingEngine.__init__().
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A configured MatchingEngine with the texts added to the index.
|
A configured MatchingEngine with the texts added to the index.
|
||||||
@ -419,6 +436,7 @@ class MatchingEngine(VectorStore):
|
|||||||
gcs_client=gcs_client,
|
gcs_client=gcs_client,
|
||||||
credentials=credentials,
|
credentials=credentials,
|
||||||
gcs_bucket_name=gcs_bucket_name,
|
gcs_bucket_name=gcs_bucket_name,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
Loading…
Reference in New Issue
Block a user