mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
core[patch]: add standard tracing params for retrievers (#25240)
This commit is contained in:
@@ -26,6 +26,8 @@ from abc import ABC, abstractmethod
|
||||
from inspect import signature
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.load.dump import dumpd
|
||||
@@ -50,6 +52,19 @@ RetrieverLike = Runnable[RetrieverInput, RetrieverOutput]
|
||||
RetrieverOutputLike = Runnable[Any, RetrieverOutput]
|
||||
|
||||
|
||||
class LangSmithRetrieverParams(TypedDict, total=False):
|
||||
"""LangSmith parameters for tracing."""
|
||||
|
||||
ls_retriever_name: str
|
||||
"""Retriever name."""
|
||||
ls_vector_store_provider: Optional[str]
|
||||
"""Vector store provider."""
|
||||
ls_embedding_provider: Optional[str]
|
||||
"""Embedding provider."""
|
||||
ls_embedding_model: Optional[str]
|
||||
"""Embedding model."""
|
||||
|
||||
|
||||
class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
|
||||
"""Abstract base class for a Document retrieval system.
|
||||
|
||||
@@ -167,6 +182,19 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
|
||||
len(set(parameters.keys()) - {"self", "query", "run_manager"}) > 0
|
||||
)
|
||||
|
||||
def _get_ls_params(self, **kwargs: Any) -> LangSmithRetrieverParams:
|
||||
"""Get standard params for tracing."""
|
||||
|
||||
default_retriever_name = self.get_name()
|
||||
if default_retriever_name.startswith("Retriever"):
|
||||
default_retriever_name = default_retriever_name[9:]
|
||||
elif default_retriever_name.endswith("Retriever"):
|
||||
default_retriever_name = default_retriever_name[:-9]
|
||||
default_retriever_name = default_retriever_name.lower()
|
||||
|
||||
ls_params = LangSmithRetrieverParams(ls_retriever_name=default_retriever_name)
|
||||
return ls_params
|
||||
|
||||
def invoke(
|
||||
self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> List[Document]:
|
||||
@@ -191,13 +219,17 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
|
||||
from langchain_core.callbacks.manager import CallbackManager
|
||||
|
||||
config = ensure_config(config)
|
||||
inheritable_metadata = {
|
||||
**(config.get("metadata") or {}),
|
||||
**self._get_ls_params(**kwargs),
|
||||
}
|
||||
callback_manager = CallbackManager.configure(
|
||||
config.get("callbacks"),
|
||||
None,
|
||||
verbose=kwargs.get("verbose", False),
|
||||
inheritable_tags=config.get("tags"),
|
||||
local_tags=self.tags,
|
||||
inheritable_metadata=config.get("metadata"),
|
||||
inheritable_metadata=inheritable_metadata,
|
||||
local_metadata=self.metadata,
|
||||
)
|
||||
run_manager = callback_manager.on_retriever_start(
|
||||
@@ -250,13 +282,17 @@ class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
|
||||
from langchain_core.callbacks.manager import AsyncCallbackManager
|
||||
|
||||
config = ensure_config(config)
|
||||
inheritable_metadata = {
|
||||
**(config.get("metadata") or {}),
|
||||
**self._get_ls_params(**kwargs),
|
||||
}
|
||||
callback_manager = AsyncCallbackManager.configure(
|
||||
config.get("callbacks"),
|
||||
None,
|
||||
verbose=kwargs.get("verbose", False),
|
||||
inheritable_tags=config.get("tags"),
|
||||
local_tags=self.tags,
|
||||
inheritable_metadata=config.get("metadata"),
|
||||
inheritable_metadata=inheritable_metadata,
|
||||
local_metadata=self.metadata,
|
||||
)
|
||||
run_manager = await callback_manager.on_retriever_start(
|
||||
|
||||
Reference in New Issue
Block a user