From e77eeee6ee8253710b9427565d3c3b987827cccf Mon Sep 17 00:00:00 2001 From: ccurme Date: Mon, 12 Aug 2024 10:51:59 -0400 Subject: [PATCH] core[patch]: add standard tracing params for retrievers (#25240) --- .../tests/unit_tests/retrievers/test_base.py | 10 +++++ .../unit_tests/retrievers/test_bedrock.py | 5 +++ .../test_databricks_vector_search.py | 22 ++++++++++ .../unit_tests/vectorstores/test_faiss.py | 9 +++++ libs/core/langchain_core/retrievers.py | 40 ++++++++++++++++++- libs/core/langchain_core/vectorstores/base.py | 21 +++++++++- 6 files changed, 104 insertions(+), 3 deletions(-) diff --git a/libs/community/tests/unit_tests/retrievers/test_base.py b/libs/community/tests/unit_tests/retrievers/test_base.py index 92d666dbedb..e8665ff2028 100644 --- a/libs/community/tests/unit_tests/retrievers/test_base.py +++ b/libs/community/tests/unit_tests/retrievers/test_base.py @@ -74,6 +74,11 @@ async def test_fake_retriever_v1_upgrade_async( assert callbacks.retriever_errors == 0 +def test_fake_retriever_v1_standard_params(fake_retriever_v1: BaseRetriever) -> None: + ls_params = fake_retriever_v1._get_ls_params() + assert ls_params == {"ls_retriever_name": "fakeretrieverv1"} + + @pytest.fixture def fake_retriever_v1_with_kwargs() -> BaseRetriever: # Test for things like the Weaviate V1 Retriever. @@ -213,3 +218,8 @@ async def test_fake_retriever_v2_async( await fake_erroring_retriever_v2.ainvoke( "Foo", config={"callbacks": [callbacks]} ) + + +def test_fake_retriever_v2_standard_params(fake_retriever_v2: BaseRetriever) -> None: + ls_params = fake_retriever_v2._get_ls_params() + assert ls_params == {"ls_retriever_name": "fakeretrieverv2"} diff --git a/libs/community/tests/unit_tests/retrievers/test_bedrock.py b/libs/community/tests/unit_tests/retrievers/test_bedrock.py index de954e6e192..ff72d193e4a 100644 --- a/libs/community/tests/unit_tests/retrievers/test_bedrock.py +++ b/libs/community/tests/unit_tests/retrievers/test_bedrock.py @@ -33,6 +33,11 @@ def test_create_client(amazon_retriever: AmazonKnowledgeBasesRetriever) -> None: amazon_retriever.create_client({}) +def test_standard_params(amazon_retriever: AmazonKnowledgeBasesRetriever) -> None: + ls_params = amazon_retriever._get_ls_params() + assert ls_params == {"ls_retriever_name": "amazonknowledgebases"} + + def test_get_relevant_documents( amazon_retriever: AmazonKnowledgeBasesRetriever, mock_client: MagicMock ) -> None: diff --git a/libs/community/tests/unit_tests/vectorstores/test_databricks_vector_search.py b/libs/community/tests/unit_tests/vectorstores/test_databricks_vector_search.py index 75c040d126b..508bf0ac1ba 100644 --- a/libs/community/tests/unit_tests/vectorstores/test_databricks_vector_search.py +++ b/libs/community/tests/unit_tests/vectorstores/test_databricks_vector_search.py @@ -633,6 +633,28 @@ def test_similarity_score_threshold(index_details: dict, threshold: float) -> No assert len(search_result) == 0 +@pytest.mark.requires("databricks", "databricks.vector_search") +def test_standard_params() -> None: + index = mock_index(DIRECT_ACCESS_INDEX) + vectorstore = default_databricks_vector_search(index) + retriever = vectorstore.as_retriever() + ls_params = retriever._get_ls_params() + assert ls_params == { + "ls_retriever_name": "vectorstore", + "ls_vector_store_provider": "DatabricksVectorSearch", + "ls_embedding_provider": "FakeEmbeddingsWithDimension", + } + + index = mock_index(DELTA_SYNC_INDEX_MANAGED_EMBEDDINGS) + vectorstore = default_databricks_vector_search(index) + retriever = vectorstore.as_retriever() + ls_params = retriever._get_ls_params() + assert ls_params == { + "ls_retriever_name": "vectorstore", + "ls_vector_store_provider": "DatabricksVectorSearch", + } + + @pytest.mark.requires("databricks", "databricks.vector_search") @pytest.mark.parametrize( "index_details", [DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS, DIRECT_ACCESS_INDEX] diff --git a/libs/community/tests/unit_tests/vectorstores/test_faiss.py b/libs/community/tests/unit_tests/vectorstores/test_faiss.py index 144f5fbb419..99b4ba6e699 100644 --- a/libs/community/tests/unit_tests/vectorstores/test_faiss.py +++ b/libs/community/tests/unit_tests/vectorstores/test_faiss.py @@ -49,6 +49,15 @@ def test_faiss() -> None: output = docsearch.similarity_search("foo", k=1) assert output == [Document(page_content="foo")] + # Retriever standard params + retriever = docsearch.as_retriever() + ls_params = retriever._get_ls_params() + assert ls_params == { + "ls_retriever_name": "vectorstore", + "ls_vector_store_provider": "FAISS", + "ls_embedding_provider": "FakeEmbeddings", + } + @pytest.mark.requires("faiss") async def test_faiss_afrom_texts() -> None: diff --git a/libs/core/langchain_core/retrievers.py b/libs/core/langchain_core/retrievers.py index 6bbf34072b7..c4971fd4962 100644 --- a/libs/core/langchain_core/retrievers.py +++ b/libs/core/langchain_core/retrievers.py @@ -26,6 +26,8 @@ from abc import ABC, abstractmethod from inspect import signature from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing_extensions import TypedDict + from langchain_core._api import deprecated from langchain_core.documents import Document from langchain_core.load.dump import dumpd @@ -50,6 +52,19 @@ RetrieverLike = Runnable[RetrieverInput, RetrieverOutput] RetrieverOutputLike = Runnable[Any, RetrieverOutput] +class LangSmithRetrieverParams(TypedDict, total=False): + """LangSmith parameters for tracing.""" + + ls_retriever_name: str + """Retriever name.""" + ls_vector_store_provider: Optional[str] + """Vector store provider.""" + ls_embedding_provider: Optional[str] + """Embedding provider.""" + ls_embedding_model: Optional[str] + """Embedding model.""" + + class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC): """Abstract base class for a Document retrieval system. @@ -167,6 +182,19 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC): len(set(parameters.keys()) - {"self", "query", "run_manager"}) > 0 ) + def _get_ls_params(self, **kwargs: Any) -> LangSmithRetrieverParams: + """Get standard params for tracing.""" + + default_retriever_name = self.get_name() + if default_retriever_name.startswith("Retriever"): + default_retriever_name = default_retriever_name[9:] + elif default_retriever_name.endswith("Retriever"): + default_retriever_name = default_retriever_name[:-9] + default_retriever_name = default_retriever_name.lower() + + ls_params = LangSmithRetrieverParams(ls_retriever_name=default_retriever_name) + return ls_params + def invoke( self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> List[Document]: @@ -191,13 +219,17 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC): from langchain_core.callbacks.manager import CallbackManager config = ensure_config(config) + inheritable_metadata = { + **(config.get("metadata") or {}), + **self._get_ls_params(**kwargs), + } callback_manager = CallbackManager.configure( config.get("callbacks"), None, verbose=kwargs.get("verbose", False), inheritable_tags=config.get("tags"), local_tags=self.tags, - inheritable_metadata=config.get("metadata"), + inheritable_metadata=inheritable_metadata, local_metadata=self.metadata, ) run_manager = callback_manager.on_retriever_start( @@ -250,13 +282,17 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC): from langchain_core.callbacks.manager import AsyncCallbackManager config = ensure_config(config) + inheritable_metadata = { + **(config.get("metadata") or {}), + **self._get_ls_params(**kwargs), + } callback_manager = AsyncCallbackManager.configure( config.get("callbacks"), None, verbose=kwargs.get("verbose", False), inheritable_tags=config.get("tags"), local_tags=self.tags, - inheritable_metadata=config.get("metadata"), + inheritable_metadata=inheritable_metadata, local_metadata=self.metadata, ) run_manager = await callback_manager.on_retriever_start( diff --git a/libs/core/langchain_core/vectorstores/base.py b/libs/core/langchain_core/vectorstores/base.py index 806549f709b..701b0a4626c 100644 --- a/libs/core/langchain_core/vectorstores/base.py +++ b/libs/core/langchain_core/vectorstores/base.py @@ -44,7 +44,7 @@ from typing import ( from langchain_core.embeddings import Embeddings from langchain_core.pydantic_v1 import Field, root_validator -from langchain_core.retrievers import BaseRetriever +from langchain_core.retrievers import BaseRetriever, LangSmithRetrieverParams from langchain_core.runnables.config import run_in_executor if TYPE_CHECKING: @@ -1014,6 +1014,25 @@ class VectorStoreRetriever(BaseRetriever): ) return values + def _get_ls_params(self, **kwargs: Any) -> LangSmithRetrieverParams: + """Get standard params for tracing.""" + + ls_params = super()._get_ls_params(**kwargs) + ls_params["ls_vector_store_provider"] = self.vectorstore.__class__.__name__ + + if self.vectorstore.embeddings: + ls_params["ls_embedding_provider"] = ( + self.vectorstore.embeddings.__class__.__name__ + ) + elif hasattr(self.vectorstore, "embedding") and isinstance( + self.vectorstore.embedding, Embeddings + ): + ls_params["ls_embedding_provider"] = ( + self.vectorstore.embedding.__class__.__name__ + ) + + return ls_params + def _get_relevant_documents( self, query: str, *, run_manager: CallbackManagerForRetrieverRun ) -> List[Document]: