mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-31 16:39:20 +00:00
support adding custom metadata to runs (#7120)
- [x] wire up tools - [x] wire up retrievers - [x] add integration test <!-- Thank you for contributing to LangChain! Replace this comment with: - Description: a description of the change, - Issue: the issue # it fixes (if applicable), - Dependencies: any dependencies required for this change, - Tag maintainer: for a quicker response, tag the relevant maintainer (see below), - Twitter handle: we announce bigger features on Twitter. If your PR gets announced and you'd like a mention, we'll gladly shout you out! If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. Maintainer responsibilities: - General / Misc / if you don't know who to tag: @baskaryan - DataLoaders / VectorStores / Retrievers: @rlancemartin, @eyurtsev - Models / Prompts: @hwchase17, @baskaryan - Memory: @hwchase17 - Agents / Tools / Toolkits: @hinthornw - Tracing / Callbacks: @agola11 - Async: @agola11 If no one reviews your PR within a few days, feel free to @-mention the same people again. See contribution guidelines for more information on how to write/run tests, lint, etc: https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md -->
This commit is contained in:
parent
30d8d1d3d0
commit
4c1c05c2c7
@ -147,6 +147,7 @@ class CallbackManagerMixin:
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when LLM starts running."""
|
||||
@ -159,6 +160,7 @@ class CallbackManagerMixin:
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when a chat model starts running."""
|
||||
@ -174,6 +176,7 @@ class CallbackManagerMixin:
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when Retriever starts running."""
|
||||
@ -186,6 +189,7 @@ class CallbackManagerMixin:
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when chain starts running."""
|
||||
@ -198,6 +202,7 @@ class CallbackManagerMixin:
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when tool starts running."""
|
||||
@ -268,6 +273,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when LLM starts running."""
|
||||
@ -280,6 +286,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run when a chat model starts running."""
|
||||
@ -328,6 +335,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when chain starts running."""
|
||||
@ -362,6 +370,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when tool starts running."""
|
||||
@ -429,6 +438,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run on retriever start."""
|
||||
@ -467,6 +477,8 @@ class BaseCallbackManager(CallbackManagerMixin):
|
||||
*,
|
||||
tags: Optional[List[str]] = None,
|
||||
inheritable_tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
inheritable_metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""Initialize callback manager."""
|
||||
self.handlers: List[BaseCallbackHandler] = handlers
|
||||
@ -476,6 +488,8 @@ class BaseCallbackManager(CallbackManagerMixin):
|
||||
self.parent_run_id: Optional[UUID] = parent_run_id
|
||||
self.tags = tags or []
|
||||
self.inheritable_tags = inheritable_tags or []
|
||||
self.metadata = metadata or {}
|
||||
self.inheritable_metadata = inheritable_metadata or {}
|
||||
|
||||
@property
|
||||
def is_async(self) -> bool:
|
||||
@ -518,3 +532,13 @@ class BaseCallbackManager(CallbackManagerMixin):
|
||||
for tag in tags:
|
||||
self.tags.remove(tag)
|
||||
self.inheritable_tags.remove(tag)
|
||||
|
||||
def add_metadata(self, metadata: Dict[str, Any], inherit: bool = True) -> None:
|
||||
self.metadata.update(metadata)
|
||||
if inherit:
|
||||
self.inheritable_metadata.update(metadata)
|
||||
|
||||
def remove_metadata(self, keys: List[str]) -> None:
|
||||
for key in keys:
|
||||
self.metadata.pop(key)
|
||||
self.inheritable_metadata.pop(key)
|
||||
|
@ -383,6 +383,8 @@ class BaseRunManager(RunManagerMixin):
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
inheritable_tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
inheritable_metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""Initialize the run manager.
|
||||
|
||||
@ -395,6 +397,8 @@ class BaseRunManager(RunManagerMixin):
|
||||
Defaults to None.
|
||||
tags (Optional[List[str]]): The list of tags.
|
||||
inheritable_tags (Optional[List[str]]): The list of inheritable tags.
|
||||
metadata (Optional[Dict[str, Any]]): The metadata.
|
||||
inheritable_metadata (Optional[Dict[str, Any]]): The inheritable metadata.
|
||||
"""
|
||||
self.run_id = run_id
|
||||
self.handlers = handlers
|
||||
@ -402,6 +406,8 @@ class BaseRunManager(RunManagerMixin):
|
||||
self.parent_run_id = parent_run_id
|
||||
self.tags = tags or []
|
||||
self.inheritable_tags = inheritable_tags or []
|
||||
self.metadata = metadata or {}
|
||||
self.inheritable_metadata = inheritable_metadata or {}
|
||||
|
||||
@classmethod
|
||||
def get_noop_manager(cls: Type[BRM]) -> BRM:
|
||||
@ -416,6 +422,8 @@ class BaseRunManager(RunManagerMixin):
|
||||
inheritable_handlers=[],
|
||||
tags=[],
|
||||
inheritable_tags=[],
|
||||
metadata={},
|
||||
inheritable_metadata={},
|
||||
)
|
||||
|
||||
|
||||
@ -447,6 +455,28 @@ class RunManager(BaseRunManager):
|
||||
)
|
||||
|
||||
|
||||
class ParentRunManager(RunManager):
|
||||
"""Sync Parent Run Manager."""
|
||||
|
||||
def get_child(self, tag: Optional[str] = None) -> CallbackManager:
|
||||
"""Get a child callback manager.
|
||||
|
||||
Args:
|
||||
tag (str, optional): The tag for the child callback manager.
|
||||
Defaults to None.
|
||||
|
||||
Returns:
|
||||
CallbackManager: The child callback manager.
|
||||
"""
|
||||
manager = CallbackManager(handlers=[], parent_run_id=self.run_id)
|
||||
manager.set_handlers(self.inheritable_handlers)
|
||||
manager.add_tags(self.inheritable_tags)
|
||||
manager.add_metadata(self.inheritable_metadata)
|
||||
if tag is not None:
|
||||
manager.add_tags([tag], False)
|
||||
return manager
|
||||
|
||||
|
||||
class AsyncRunManager(BaseRunManager):
|
||||
"""Async Run Manager."""
|
||||
|
||||
@ -475,6 +505,28 @@ class AsyncRunManager(BaseRunManager):
|
||||
)
|
||||
|
||||
|
||||
class AsyncParentRunManager(AsyncRunManager):
|
||||
"""Async Parent Run Manager."""
|
||||
|
||||
def get_child(self, tag: Optional[str] = None) -> AsyncCallbackManager:
|
||||
"""Get a child callback manager.
|
||||
|
||||
Args:
|
||||
tag (str, optional): The tag for the child callback manager.
|
||||
Defaults to None.
|
||||
|
||||
Returns:
|
||||
AsyncCallbackManager: The child callback manager.
|
||||
"""
|
||||
manager = AsyncCallbackManager(handlers=[], parent_run_id=self.run_id)
|
||||
manager.set_handlers(self.inheritable_handlers)
|
||||
manager.add_tags(self.inheritable_tags)
|
||||
manager.add_metadata(self.inheritable_metadata)
|
||||
if tag is not None:
|
||||
manager.add_tags([tag], False)
|
||||
return manager
|
||||
|
||||
|
||||
class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
|
||||
"""Callback manager for LLM run."""
|
||||
|
||||
@ -601,26 +653,9 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
|
||||
)
|
||||
|
||||
|
||||
class CallbackManagerForChainRun(RunManager, ChainManagerMixin):
|
||||
class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin):
|
||||
"""Callback manager for chain run."""
|
||||
|
||||
def get_child(self, tag: Optional[str] = None) -> CallbackManager:
|
||||
"""Get a child callback manager.
|
||||
|
||||
Args:
|
||||
tag (str, optional): The tag for the child callback manager.
|
||||
Defaults to None.
|
||||
|
||||
Returns:
|
||||
CallbackManager: The child callback manager.
|
||||
"""
|
||||
manager = CallbackManager(handlers=[], parent_run_id=self.run_id)
|
||||
manager.set_handlers(self.inheritable_handlers)
|
||||
manager.add_tags(self.inheritable_tags)
|
||||
if tag is not None:
|
||||
manager.add_tags([tag], False)
|
||||
return manager
|
||||
|
||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
"""Run when chain ends running.
|
||||
|
||||
@ -700,26 +735,9 @@ class CallbackManagerForChainRun(RunManager, ChainManagerMixin):
|
||||
)
|
||||
|
||||
|
||||
class AsyncCallbackManagerForChainRun(AsyncRunManager, ChainManagerMixin):
|
||||
class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
|
||||
"""Async callback manager for chain run."""
|
||||
|
||||
def get_child(self, tag: Optional[str] = None) -> AsyncCallbackManager:
|
||||
"""Get a child callback manager.
|
||||
|
||||
Args:
|
||||
tag (str, optional): The tag for the child callback manager.
|
||||
Defaults to None.
|
||||
|
||||
Returns:
|
||||
AsyncCallbackManager: The child callback manager.
|
||||
"""
|
||||
manager = AsyncCallbackManager(handlers=[], parent_run_id=self.run_id)
|
||||
manager.set_handlers(self.inheritable_handlers)
|
||||
manager.add_tags(self.inheritable_tags)
|
||||
if tag is not None:
|
||||
manager.add_tags([tag], False)
|
||||
return manager
|
||||
|
||||
async def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||
"""Run when chain ends running.
|
||||
|
||||
@ -799,26 +817,9 @@ class AsyncCallbackManagerForChainRun(AsyncRunManager, ChainManagerMixin):
|
||||
)
|
||||
|
||||
|
||||
class CallbackManagerForToolRun(RunManager, ToolManagerMixin):
|
||||
class CallbackManagerForToolRun(ParentRunManager, ToolManagerMixin):
|
||||
"""Callback manager for tool run."""
|
||||
|
||||
def get_child(self, tag: Optional[str] = None) -> CallbackManager:
|
||||
"""Get a child callback manager.
|
||||
|
||||
Args:
|
||||
tag (str, optional): The tag for the child callback manager.
|
||||
Defaults to None.
|
||||
|
||||
Returns:
|
||||
CallbackManager: The child callback manager.
|
||||
"""
|
||||
manager = CallbackManager(handlers=[], parent_run_id=self.run_id)
|
||||
manager.set_handlers(self.inheritable_handlers)
|
||||
manager.add_tags(self.inheritable_tags)
|
||||
if tag is not None:
|
||||
manager.add_tags([tag], False)
|
||||
return manager
|
||||
|
||||
def on_tool_end(
|
||||
self,
|
||||
output: str,
|
||||
@ -862,26 +863,9 @@ class CallbackManagerForToolRun(RunManager, ToolManagerMixin):
|
||||
)
|
||||
|
||||
|
||||
class AsyncCallbackManagerForToolRun(AsyncRunManager, ToolManagerMixin):
|
||||
class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin):
|
||||
"""Async callback manager for tool run."""
|
||||
|
||||
def get_child(self, tag: Optional[str] = None) -> AsyncCallbackManager:
|
||||
"""Get a child callback manager.
|
||||
|
||||
Args:
|
||||
tag (str, optional): The tag to add to the child
|
||||
callback manager. Defaults to None.
|
||||
|
||||
Returns:
|
||||
AsyncCallbackManager: The child callback manager.
|
||||
"""
|
||||
manager = AsyncCallbackManager(handlers=[], parent_run_id=self.run_id)
|
||||
manager.set_handlers(self.inheritable_handlers)
|
||||
manager.add_tags(self.inheritable_tags)
|
||||
if tag is not None:
|
||||
manager.add_tags([tag], False)
|
||||
return manager
|
||||
|
||||
async def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
||||
"""Run when tool ends running.
|
||||
|
||||
@ -921,18 +905,9 @@ class AsyncCallbackManagerForToolRun(AsyncRunManager, ToolManagerMixin):
|
||||
)
|
||||
|
||||
|
||||
class CallbackManagerForRetrieverRun(RunManager, RetrieverManagerMixin):
|
||||
class CallbackManagerForRetrieverRun(ParentRunManager, RetrieverManagerMixin):
|
||||
"""Callback manager for retriever run."""
|
||||
|
||||
def get_child(self, tag: Optional[str] = None) -> CallbackManager:
|
||||
"""Get a child callback manager."""
|
||||
manager = CallbackManager([], parent_run_id=self.run_id)
|
||||
manager.set_handlers(self.inheritable_handlers)
|
||||
manager.add_tags(self.inheritable_tags)
|
||||
if tag is not None:
|
||||
manager.add_tags([tag], False)
|
||||
return manager
|
||||
|
||||
def on_retriever_end(
|
||||
self,
|
||||
documents: Sequence[Document],
|
||||
@ -969,20 +944,11 @@ class CallbackManagerForRetrieverRun(RunManager, RetrieverManagerMixin):
|
||||
|
||||
|
||||
class AsyncCallbackManagerForRetrieverRun(
|
||||
AsyncRunManager,
|
||||
AsyncParentRunManager,
|
||||
RetrieverManagerMixin,
|
||||
):
|
||||
"""Async callback manager for retriever run."""
|
||||
|
||||
def get_child(self, tag: Optional[str] = None) -> AsyncCallbackManager:
|
||||
"""Get a child callback manager."""
|
||||
manager = AsyncCallbackManager([], parent_run_id=self.run_id)
|
||||
manager.set_handlers(self.inheritable_handlers)
|
||||
manager.add_tags(self.inheritable_tags)
|
||||
if tag is not None:
|
||||
manager.add_tags([tag], False)
|
||||
return manager
|
||||
|
||||
async def on_retriever_end(
|
||||
self, documents: Sequence[Document], **kwargs: Any
|
||||
) -> None:
|
||||
@ -1048,6 +1014,7 @@ class CallbackManager(BaseCallbackManager):
|
||||
run_id=run_id_,
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
metadata=self.metadata,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -1059,6 +1026,8 @@ class CallbackManager(BaseCallbackManager):
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
inheritable_tags=self.inheritable_tags,
|
||||
metadata=self.metadata,
|
||||
inheritable_metadata=self.inheritable_metadata,
|
||||
)
|
||||
)
|
||||
|
||||
@ -1094,6 +1063,7 @@ class CallbackManager(BaseCallbackManager):
|
||||
run_id=run_id_,
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
metadata=self.metadata,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -1105,6 +1075,8 @@ class CallbackManager(BaseCallbackManager):
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
inheritable_tags=self.inheritable_tags,
|
||||
metadata=self.metadata,
|
||||
inheritable_metadata=self.inheritable_metadata,
|
||||
)
|
||||
)
|
||||
|
||||
@ -1139,6 +1111,7 @@ class CallbackManager(BaseCallbackManager):
|
||||
run_id=run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
metadata=self.metadata,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -1149,6 +1122,8 @@ class CallbackManager(BaseCallbackManager):
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
inheritable_tags=self.inheritable_tags,
|
||||
metadata=self.metadata,
|
||||
inheritable_metadata=self.inheritable_metadata,
|
||||
)
|
||||
|
||||
def on_tool_start(
|
||||
@ -1182,6 +1157,7 @@ class CallbackManager(BaseCallbackManager):
|
||||
run_id=run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
metadata=self.metadata,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -1192,6 +1168,8 @@ class CallbackManager(BaseCallbackManager):
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
inheritable_tags=self.inheritable_tags,
|
||||
metadata=self.metadata,
|
||||
inheritable_metadata=self.inheritable_metadata,
|
||||
)
|
||||
|
||||
def on_retriever_start(
|
||||
@ -1215,6 +1193,7 @@ class CallbackManager(BaseCallbackManager):
|
||||
run_id=run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
metadata=self.metadata,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -1225,6 +1204,8 @@ class CallbackManager(BaseCallbackManager):
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
inheritable_tags=self.inheritable_tags,
|
||||
metadata=self.metadata,
|
||||
inheritable_metadata=self.inheritable_metadata,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -1235,6 +1216,8 @@ class CallbackManager(BaseCallbackManager):
|
||||
verbose: bool = False,
|
||||
inheritable_tags: Optional[List[str]] = None,
|
||||
local_tags: Optional[List[str]] = None,
|
||||
inheritable_metadata: Optional[Dict[str, Any]] = None,
|
||||
local_metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> CallbackManager:
|
||||
"""Configure the callback manager.
|
||||
|
||||
@ -1248,6 +1231,10 @@ class CallbackManager(BaseCallbackManager):
|
||||
Defaults to None.
|
||||
local_tags (Optional[List[str]], optional): The local tags.
|
||||
Defaults to None.
|
||||
inheritable_metadata (Optional[Dict[str, Any]], optional): The inheritable
|
||||
metadata. Defaults to None.
|
||||
local_metadata (Optional[Dict[str, Any]], optional): The local metadata.
|
||||
Defaults to None.
|
||||
|
||||
Returns:
|
||||
CallbackManager: The configured callback manager.
|
||||
@ -1259,6 +1246,8 @@ class CallbackManager(BaseCallbackManager):
|
||||
verbose,
|
||||
inheritable_tags,
|
||||
local_tags,
|
||||
inheritable_metadata,
|
||||
local_metadata,
|
||||
)
|
||||
|
||||
|
||||
@ -1305,6 +1294,7 @@ class AsyncCallbackManager(BaseCallbackManager):
|
||||
run_id=run_id_,
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
metadata=self.metadata,
|
||||
**kwargs,
|
||||
)
|
||||
)
|
||||
@ -1317,6 +1307,8 @@ class AsyncCallbackManager(BaseCallbackManager):
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
inheritable_tags=self.inheritable_tags,
|
||||
metadata=self.metadata,
|
||||
inheritable_metadata=self.inheritable_metadata,
|
||||
)
|
||||
)
|
||||
|
||||
@ -1358,6 +1350,7 @@ class AsyncCallbackManager(BaseCallbackManager):
|
||||
run_id=run_id_,
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
metadata=self.metadata,
|
||||
**kwargs,
|
||||
)
|
||||
)
|
||||
@ -1370,6 +1363,8 @@ class AsyncCallbackManager(BaseCallbackManager):
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
inheritable_tags=self.inheritable_tags,
|
||||
metadata=self.metadata,
|
||||
inheritable_metadata=self.inheritable_metadata,
|
||||
)
|
||||
)
|
||||
|
||||
@ -1406,6 +1401,7 @@ class AsyncCallbackManager(BaseCallbackManager):
|
||||
run_id=run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
metadata=self.metadata,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -1416,6 +1412,8 @@ class AsyncCallbackManager(BaseCallbackManager):
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
inheritable_tags=self.inheritable_tags,
|
||||
metadata=self.metadata,
|
||||
inheritable_metadata=self.inheritable_metadata,
|
||||
)
|
||||
|
||||
async def on_tool_start(
|
||||
@ -1451,6 +1449,7 @@ class AsyncCallbackManager(BaseCallbackManager):
|
||||
run_id=run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
metadata=self.metadata,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -1461,6 +1460,8 @@ class AsyncCallbackManager(BaseCallbackManager):
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
inheritable_tags=self.inheritable_tags,
|
||||
metadata=self.metadata,
|
||||
inheritable_metadata=self.inheritable_metadata,
|
||||
)
|
||||
|
||||
async def on_retriever_start(
|
||||
@ -1484,6 +1485,7 @@ class AsyncCallbackManager(BaseCallbackManager):
|
||||
run_id=run_id,
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
metadata=self.metadata,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -1494,6 +1496,8 @@ class AsyncCallbackManager(BaseCallbackManager):
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
inheritable_tags=self.inheritable_tags,
|
||||
metadata=self.metadata,
|
||||
inheritable_metadata=self.inheritable_metadata,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -1504,6 +1508,8 @@ class AsyncCallbackManager(BaseCallbackManager):
|
||||
verbose: bool = False,
|
||||
inheritable_tags: Optional[List[str]] = None,
|
||||
local_tags: Optional[List[str]] = None,
|
||||
inheritable_metadata: Optional[Dict[str, Any]] = None,
|
||||
local_metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> AsyncCallbackManager:
|
||||
"""Configure the async callback manager.
|
||||
|
||||
@ -1517,6 +1523,10 @@ class AsyncCallbackManager(BaseCallbackManager):
|
||||
Defaults to None.
|
||||
local_tags (Optional[List[str]], optional): The local tags.
|
||||
Defaults to None.
|
||||
inheritable_metadata (Optional[Dict[str, Any]], optional): The inheritable
|
||||
metadata. Defaults to None.
|
||||
local_metadata (Optional[Dict[str, Any]], optional): The local metadata.
|
||||
Defaults to None.
|
||||
|
||||
Returns:
|
||||
AsyncCallbackManager: The configured async callback manager.
|
||||
@ -1528,6 +1538,8 @@ class AsyncCallbackManager(BaseCallbackManager):
|
||||
verbose,
|
||||
inheritable_tags,
|
||||
local_tags,
|
||||
inheritable_metadata,
|
||||
local_metadata,
|
||||
)
|
||||
|
||||
|
||||
@ -1558,6 +1570,8 @@ def _configure(
|
||||
verbose: bool = False,
|
||||
inheritable_tags: Optional[List[str]] = None,
|
||||
local_tags: Optional[List[str]] = None,
|
||||
inheritable_metadata: Optional[Dict[str, Any]] = None,
|
||||
local_metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> T:
|
||||
"""Configure the callback manager.
|
||||
|
||||
@ -1571,6 +1585,10 @@ def _configure(
|
||||
inheritable_tags (Optional[List[str]], optional): The inheritable tags.
|
||||
Defaults to None.
|
||||
local_tags (Optional[List[str]], optional): The local tags. Defaults to None.
|
||||
inheritable_metadata (Optional[Dict[str, Any]], optional): The inheritable
|
||||
metadata. Defaults to None.
|
||||
local_metadata (Optional[Dict[str, Any]], optional): The local metadata.
|
||||
Defaults to None.
|
||||
|
||||
Returns:
|
||||
T: The configured callback manager.
|
||||
@ -1590,6 +1608,8 @@ def _configure(
|
||||
parent_run_id=inheritable_callbacks.parent_run_id,
|
||||
tags=inheritable_callbacks.tags,
|
||||
inheritable_tags=inheritable_callbacks.inheritable_tags,
|
||||
metadata=inheritable_callbacks.metadata,
|
||||
inheritable_metadata=inheritable_callbacks.inheritable_metadata,
|
||||
)
|
||||
local_handlers_ = (
|
||||
local_callbacks
|
||||
@ -1601,6 +1621,9 @@ def _configure(
|
||||
if inheritable_tags or local_tags:
|
||||
callback_manager.add_tags(inheritable_tags or [])
|
||||
callback_manager.add_tags(local_tags or [], False)
|
||||
if inheritable_metadata or local_metadata:
|
||||
callback_manager.add_metadata(inheritable_metadata or {})
|
||||
callback_manager.add_metadata(local_metadata or {}, False)
|
||||
|
||||
tracer = tracing_callback_var.get()
|
||||
wandb_tracer = wandb_tracing_callback_var.get()
|
||||
|
@ -89,12 +89,15 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
run_id: UUID,
|
||||
tags: Optional[List[str]] = None,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Start a trace for an LLM run."""
|
||||
parent_run_id_ = str(parent_run_id) if parent_run_id else None
|
||||
execution_order = self._get_execution_order(parent_run_id_)
|
||||
start_time = datetime.utcnow()
|
||||
if metadata:
|
||||
kwargs.update({"metadata": metadata})
|
||||
llm_run = Run(
|
||||
id=run_id,
|
||||
parent_run_id=parent_run_id,
|
||||
@ -186,12 +189,15 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
run_id: UUID,
|
||||
tags: Optional[List[str]] = None,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Start a trace for a chain run."""
|
||||
parent_run_id_ = str(parent_run_id) if parent_run_id else None
|
||||
execution_order = self._get_execution_order(parent_run_id_)
|
||||
start_time = datetime.utcnow()
|
||||
if metadata:
|
||||
kwargs.update({"metadata": metadata})
|
||||
chain_run = Run(
|
||||
id=run_id,
|
||||
parent_run_id=parent_run_id,
|
||||
@ -253,12 +259,15 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
run_id: UUID,
|
||||
tags: Optional[List[str]] = None,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Start a trace for a tool run."""
|
||||
parent_run_id_ = str(parent_run_id) if parent_run_id else None
|
||||
execution_order = self._get_execution_order(parent_run_id_)
|
||||
start_time = datetime.utcnow()
|
||||
if metadata:
|
||||
kwargs.update({"metadata": metadata})
|
||||
tool_run = Run(
|
||||
id=run_id,
|
||||
parent_run_id=parent_run_id,
|
||||
@ -317,12 +326,15 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Run when Retriever starts running."""
|
||||
parent_run_id_ = str(parent_run_id) if parent_run_id else None
|
||||
execution_order = self._get_execution_order(parent_run_id_)
|
||||
start_time = datetime.utcnow()
|
||||
if metadata:
|
||||
kwargs.update({"metadata": metadata})
|
||||
retrieval_run = Run(
|
||||
id=run_id,
|
||||
name="Retriever",
|
||||
|
@ -70,12 +70,15 @@ class LangChainTracer(BaseTracer):
|
||||
run_id: UUID,
|
||||
tags: Optional[List[str]] = None,
|
||||
parent_run_id: Optional[UUID] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Start a trace for an LLM run."""
|
||||
parent_run_id_ = str(parent_run_id) if parent_run_id else None
|
||||
execution_order = self._get_execution_order(parent_run_id_)
|
||||
start_time = datetime.utcnow()
|
||||
if metadata:
|
||||
kwargs.update({"metadata": metadata})
|
||||
chat_model_run = Run(
|
||||
id=run_id,
|
||||
parent_run_id=parent_run_id,
|
||||
|
@ -54,6 +54,12 @@ class Chain(Serializable, ABC):
|
||||
and passed as arguments to the handlers defined in `callbacks`.
|
||||
You can use these to eg identify a specific instance of a chain with its use case.
|
||||
"""
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
"""Optional metadata associated with the chain. Defaults to None
|
||||
This metadata will be associated with each call to this chain,
|
||||
and passed as arguments to the handlers defined in `callbacks`.
|
||||
You can use these to eg identify a specific instance of a chain with its use case.
|
||||
"""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
@ -130,6 +136,7 @@ class Chain(Serializable, ABC):
|
||||
callbacks: Callbacks = None,
|
||||
*,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
include_run_info: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
"""Run the logic of this chain and add to output if desired.
|
||||
@ -143,12 +150,20 @@ class Chain(Serializable, ABC):
|
||||
chain will be returned. Defaults to False.
|
||||
callbacks: Callbacks to use for this chain run. If not provided, will
|
||||
use the callbacks provided to the chain.
|
||||
tags: Optional list of tags associated with the chain. Defaults to None
|
||||
metadata: Optional metadata associated with the chain. Defaults to None
|
||||
include_run_info: Whether to include run info in the response. Defaults
|
||||
to False.
|
||||
"""
|
||||
inputs = self.prep_inputs(inputs)
|
||||
callback_manager = CallbackManager.configure(
|
||||
callbacks, self.callbacks, self.verbose, tags, self.tags
|
||||
callbacks,
|
||||
self.callbacks,
|
||||
self.verbose,
|
||||
tags,
|
||||
self.tags,
|
||||
metadata,
|
||||
self.metadata,
|
||||
)
|
||||
new_arg_supported = inspect.signature(self._call).parameters.get("run_manager")
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
@ -179,6 +194,7 @@ class Chain(Serializable, ABC):
|
||||
callbacks: Callbacks = None,
|
||||
*,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
include_run_info: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
"""Run the logic of this chain and add to output if desired.
|
||||
@ -192,12 +208,20 @@ class Chain(Serializable, ABC):
|
||||
chain will be returned. Defaults to False.
|
||||
callbacks: Callbacks to use for this chain run. If not provided, will
|
||||
use the callbacks provided to the chain.
|
||||
tags: Optional list of tags associated with the chain. Defaults to None
|
||||
metadata: Optional metadata associated with the chain. Defaults to None
|
||||
include_run_info: Whether to include run info in the response. Defaults
|
||||
to False.
|
||||
"""
|
||||
inputs = self.prep_inputs(inputs)
|
||||
callback_manager = AsyncCallbackManager.configure(
|
||||
callbacks, self.callbacks, self.verbose, tags, self.tags
|
||||
callbacks,
|
||||
self.callbacks,
|
||||
self.verbose,
|
||||
tags,
|
||||
self.tags,
|
||||
metadata,
|
||||
self.metadata,
|
||||
)
|
||||
new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager")
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
@ -278,6 +302,7 @@ class Chain(Serializable, ABC):
|
||||
*args: Any,
|
||||
callbacks: Callbacks = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Run the chain as text in, text out or multiple variables, text out."""
|
||||
@ -287,10 +312,14 @@ class Chain(Serializable, ABC):
|
||||
if args and not kwargs:
|
||||
if len(args) != 1:
|
||||
raise ValueError("`run` supports only one positional argument.")
|
||||
return self(args[0], callbacks=callbacks, tags=tags)[_output_key]
|
||||
return self(args[0], callbacks=callbacks, tags=tags, metadata=metadata)[
|
||||
_output_key
|
||||
]
|
||||
|
||||
if kwargs and not args:
|
||||
return self(kwargs, callbacks=callbacks, tags=tags)[_output_key]
|
||||
return self(kwargs, callbacks=callbacks, tags=tags, metadata=metadata)[
|
||||
_output_key
|
||||
]
|
||||
|
||||
if not kwargs and not args:
|
||||
raise ValueError(
|
||||
@ -308,6 +337,7 @@ class Chain(Serializable, ABC):
|
||||
*args: Any,
|
||||
callbacks: Callbacks = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Run the chain as text in, text out or multiple variables, text out."""
|
||||
@ -320,14 +350,18 @@ class Chain(Serializable, ABC):
|
||||
if args and not kwargs:
|
||||
if len(args) != 1:
|
||||
raise ValueError("`run` supports only one positional argument.")
|
||||
return (await self.acall(args[0], callbacks=callbacks, tags=tags))[
|
||||
self.output_keys[0]
|
||||
]
|
||||
return (
|
||||
await self.acall(
|
||||
args[0], callbacks=callbacks, tags=tags, metadata=metadata
|
||||
)
|
||||
)[self.output_keys[0]]
|
||||
|
||||
if kwargs and not args:
|
||||
return (await self.acall(kwargs, callbacks=callbacks, tags=tags))[
|
||||
self.output_keys[0]
|
||||
]
|
||||
return (
|
||||
await self.acall(
|
||||
kwargs, callbacks=callbacks, tags=tags, metadata=metadata
|
||||
)
|
||||
)[self.output_keys[0]]
|
||||
|
||||
raise ValueError(
|
||||
f"`run` supported with either positional arguments or keyword arguments"
|
||||
|
@ -40,6 +40,8 @@ class BaseChatModel(BaseLanguageModel, ABC):
|
||||
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
|
||||
tags: Optional[List[str]] = Field(default=None, exclude=True)
|
||||
"""Tags to add to the run trace."""
|
||||
metadata: Optional[Dict[str, Any]] = Field(default=None, exclude=True)
|
||||
"""Metadata to add to the run trace."""
|
||||
|
||||
@root_validator()
|
||||
def raise_deprecation(cls, values: Dict) -> Dict:
|
||||
@ -86,6 +88,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
|
||||
callbacks: Callbacks = None,
|
||||
*,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
"""Top Level call"""
|
||||
@ -98,6 +101,8 @@ class BaseChatModel(BaseLanguageModel, ABC):
|
||||
self.verbose,
|
||||
tags,
|
||||
self.tags,
|
||||
metadata,
|
||||
self.metadata,
|
||||
)
|
||||
run_managers = callback_manager.on_chat_model_start(
|
||||
dumpd(self), messages, invocation_params=params, options=options
|
||||
@ -139,6 +144,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
|
||||
callbacks: Callbacks = None,
|
||||
*,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
"""Top Level call"""
|
||||
@ -151,6 +157,8 @@ class BaseChatModel(BaseLanguageModel, ABC):
|
||||
self.verbose,
|
||||
tags,
|
||||
self.tags,
|
||||
metadata,
|
||||
self.metadata,
|
||||
)
|
||||
|
||||
run_managers = await callback_manager.on_chat_model_start(
|
||||
|
@ -244,6 +244,7 @@ The following is the expected answer. Use this to measure correctness:
|
||||
callbacks: Callbacks = None,
|
||||
*,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
include_run_info: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
"""Run the logic of this chain and add to output if desired.
|
||||
@ -257,6 +258,8 @@ The following is the expected answer. Use this to measure correctness:
|
||||
chain will be returned. Defaults to False.
|
||||
callbacks: Callbacks to use for this chain run. If not provided, will
|
||||
use the callbacks provided to the chain.
|
||||
tags: Tags to add to the chain run.
|
||||
metadata: Metadata to add to the chain run.
|
||||
include_run_info: Whether to include run info in the response. Defaults
|
||||
to False.
|
||||
"""
|
||||
|
@ -80,6 +80,8 @@ class BaseLLM(BaseLanguageModel, ABC):
|
||||
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
|
||||
tags: Optional[List[str]] = Field(default=None, exclude=True)
|
||||
"""Tags to add to the run trace."""
|
||||
metadata: Optional[Dict[str, Any]] = Field(default=None, exclude=True)
|
||||
"""Metadata to add to the run trace."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
@ -190,6 +192,7 @@ class BaseLLM(BaseLanguageModel, ABC):
|
||||
callbacks: Callbacks = None,
|
||||
*,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
"""Run the LLM on the given prompt and input."""
|
||||
@ -209,7 +212,13 @@ class BaseLLM(BaseLanguageModel, ABC):
|
||||
) = get_prompts(params, prompts)
|
||||
disregard_cache = self.cache is not None and not self.cache
|
||||
callback_manager = CallbackManager.configure(
|
||||
callbacks, self.callbacks, self.verbose, tags, self.tags
|
||||
callbacks,
|
||||
self.callbacks,
|
||||
self.verbose,
|
||||
tags,
|
||||
self.tags,
|
||||
metadata,
|
||||
self.metadata,
|
||||
)
|
||||
new_arg_supported = inspect.signature(self._generate).parameters.get(
|
||||
"run_manager"
|
||||
@ -293,6 +302,7 @@ class BaseLLM(BaseLanguageModel, ABC):
|
||||
callbacks: Callbacks = None,
|
||||
*,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
"""Run the LLM on the given prompt and input."""
|
||||
@ -307,7 +317,13 @@ class BaseLLM(BaseLanguageModel, ABC):
|
||||
) = get_prompts(params, prompts)
|
||||
disregard_cache = self.cache is not None and not self.cache
|
||||
callback_manager = AsyncCallbackManager.configure(
|
||||
callbacks, self.callbacks, self.verbose, tags, self.tags
|
||||
callbacks,
|
||||
self.callbacks,
|
||||
self.verbose,
|
||||
tags,
|
||||
self.tags,
|
||||
metadata,
|
||||
self.metadata,
|
||||
)
|
||||
new_arg_supported = inspect.signature(self._agenerate).parameters.get(
|
||||
"run_manager"
|
||||
@ -350,6 +366,9 @@ class BaseLLM(BaseLanguageModel, ABC):
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
*,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Check Cache and run the LLM on the given prompt and input."""
|
||||
@ -360,7 +379,14 @@ class BaseLLM(BaseLanguageModel, ABC):
|
||||
"`generate` instead."
|
||||
)
|
||||
return (
|
||||
self.generate([prompt], stop=stop, callbacks=callbacks, **kwargs)
|
||||
self.generate(
|
||||
[prompt],
|
||||
stop=stop,
|
||||
callbacks=callbacks,
|
||||
tags=tags,
|
||||
metadata=metadata,
|
||||
**kwargs,
|
||||
)
|
||||
.generations[0][0]
|
||||
.text
|
||||
)
|
||||
@ -370,11 +396,19 @@ class BaseLLM(BaseLanguageModel, ABC):
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
*,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Check Cache and run the LLM on the given prompt and input."""
|
||||
result = await self.agenerate(
|
||||
[prompt], stop=stop, callbacks=callbacks, **kwargs
|
||||
[prompt],
|
||||
stop=stop,
|
||||
callbacks=callbacks,
|
||||
tags=tags,
|
||||
metadata=metadata,
|
||||
**kwargs,
|
||||
)
|
||||
return result.generations[0][0].text
|
||||
|
||||
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from inspect import signature
|
||||
from typing import TYPE_CHECKING, Any, List
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
from langchain.load.dump import dumpd
|
||||
from langchain.load.serializable import Serializable
|
||||
@ -55,6 +55,20 @@ class BaseRetriever(Serializable, ABC):
|
||||
|
||||
_new_arg_supported: bool = False
|
||||
_expects_other_args: bool = False
|
||||
tags: Optional[List[str]] = None
|
||||
"""Optional list of tags associated with the retriever. Defaults to None
|
||||
These tags will be associated with each call to this retriever,
|
||||
and passed as arguments to the handlers defined in `callbacks`.
|
||||
You can use these to eg identify a specific instance of a retriever with its
|
||||
use case.
|
||||
"""
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
"""Optional metadata associated with the retriever. Defaults to None
|
||||
This metadata will be associated with each call to this retriever,
|
||||
and passed as arguments to the handlers defined in `callbacks`.
|
||||
You can use these to eg identify a specific instance of a retriever with its
|
||||
use case.
|
||||
"""
|
||||
|
||||
def __init_subclass__(cls, **kwargs: Any) -> None:
|
||||
super().__init_subclass__(**kwargs)
|
||||
@ -117,19 +131,37 @@ class BaseRetriever(Serializable, ABC):
|
||||
"""
|
||||
|
||||
def get_relevant_documents(
|
||||
self, query: str, *, callbacks: Callbacks = None, **kwargs: Any
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
callbacks: Callbacks = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Retrieve documents relevant to a query.
|
||||
Args:
|
||||
query: string to find relevant documents for
|
||||
callbacks: Callback manager or list of callbacks
|
||||
tags: Optional list of tags associated with the retriever. Defaults to None
|
||||
These tags will be associated with each call to this retriever,
|
||||
and passed as arguments to the handlers defined in `callbacks`.
|
||||
metadata: Optional metadata associated with the retriever. Defaults to None
|
||||
This metadata will be associated with each call to this retriever,
|
||||
and passed as arguments to the handlers defined in `callbacks`.
|
||||
Returns:
|
||||
List of relevant documents
|
||||
"""
|
||||
from langchain.callbacks.manager import CallbackManager
|
||||
|
||||
callback_manager = CallbackManager.configure(
|
||||
callbacks, None, verbose=kwargs.get("verbose", False)
|
||||
callbacks,
|
||||
None,
|
||||
verbose=kwargs.get("verbose", False),
|
||||
inheritable_tags=tags,
|
||||
local_tags=self.tags,
|
||||
inheritable_metadata=metadata,
|
||||
local_metadata=self.metadata,
|
||||
)
|
||||
run_manager = callback_manager.on_retriever_start(
|
||||
dumpd(self),
|
||||
@ -155,19 +187,37 @@ class BaseRetriever(Serializable, ABC):
|
||||
return result
|
||||
|
||||
async def aget_relevant_documents(
|
||||
self, query: str, *, callbacks: Callbacks = None, **kwargs: Any
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
callbacks: Callbacks = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Asynchronously get documents relevant to a query.
|
||||
Args:
|
||||
query: string to find relevant documents for
|
||||
callbacks: Callback manager or list of callbacks
|
||||
tags: Optional list of tags associated with the retriever. Defaults to None
|
||||
These tags will be associated with each call to this retriever,
|
||||
and passed as arguments to the handlers defined in `callbacks`.
|
||||
metadata: Optional metadata associated with the retriever. Defaults to None
|
||||
This metadata will be associated with each call to this retriever,
|
||||
and passed as arguments to the handlers defined in `callbacks`.
|
||||
Returns:
|
||||
List of relevant documents
|
||||
"""
|
||||
from langchain.callbacks.manager import AsyncCallbackManager
|
||||
|
||||
callback_manager = AsyncCallbackManager.configure(
|
||||
callbacks, None, verbose=kwargs.get("verbose", False)
|
||||
callbacks,
|
||||
None,
|
||||
verbose=kwargs.get("verbose", False),
|
||||
inheritable_tags=tags,
|
||||
local_tags=self.tags,
|
||||
inheritable_metadata=metadata,
|
||||
local_metadata=self.metadata,
|
||||
)
|
||||
run_manager = await callback_manager.on_retriever_start(
|
||||
dumpd(self),
|
||||
|
@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from inspect import signature
|
||||
from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, Type, Union
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
@ -153,6 +153,18 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
|
||||
"""Callbacks to be called during tool execution."""
|
||||
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
|
||||
"""Deprecated. Please use callbacks instead."""
|
||||
tags: Optional[List[str]] = None
|
||||
"""Optional list of tags associated with the tool. Defaults to None
|
||||
These tags will be associated with each call to this tool,
|
||||
and passed as arguments to the handlers defined in `callbacks`.
|
||||
You can use these to eg identify a specific instance of a tool with its use case.
|
||||
"""
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
"""Optional metadata associated with the tool. Defaults to None
|
||||
This metadata will be associated with each call to this tool,
|
||||
and passed as arguments to the handlers defined in `callbacks`.
|
||||
You can use these to eg identify a specific instance of a tool with its use case.
|
||||
"""
|
||||
|
||||
handle_tool_error: Optional[
|
||||
Union[bool, str, Callable[[ToolException], str]]
|
||||
@ -246,6 +258,9 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
|
||||
start_color: Optional[str] = "green",
|
||||
color: Optional[str] = "green",
|
||||
callbacks: Callbacks = None,
|
||||
*,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run the tool."""
|
||||
@ -255,7 +270,13 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
|
||||
else:
|
||||
verbose_ = self.verbose
|
||||
callback_manager = CallbackManager.configure(
|
||||
callbacks, self.callbacks, verbose=verbose_
|
||||
callbacks,
|
||||
self.callbacks,
|
||||
verbose_,
|
||||
tags,
|
||||
self.tags,
|
||||
metadata,
|
||||
self.metadata,
|
||||
)
|
||||
# TODO: maybe also pass through run_manager is _run supports kwargs
|
||||
new_arg_supported = signature(self._run).parameters.get("run_manager")
|
||||
@ -310,6 +331,9 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
|
||||
start_color: Optional[str] = "green",
|
||||
color: Optional[str] = "green",
|
||||
callbacks: Callbacks = None,
|
||||
*,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Run the tool asynchronously."""
|
||||
@ -319,7 +343,13 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
|
||||
else:
|
||||
verbose_ = self.verbose
|
||||
callback_manager = AsyncCallbackManager.configure(
|
||||
callbacks, self.callbacks, verbose=verbose_
|
||||
callbacks,
|
||||
self.callbacks,
|
||||
verbose_,
|
||||
tags,
|
||||
self.tags,
|
||||
metadata,
|
||||
self.metadata,
|
||||
)
|
||||
new_arg_supported = signature(self._arun).parameters.get("run_manager")
|
||||
run_manager = await callback_manager.on_tool_start(
|
||||
|
@ -181,6 +181,40 @@ def test_tracing_v2_chain_with_tags() -> None:
|
||||
chain.run("what is the meaning of life", tags=["a-tag"])
|
||||
|
||||
|
||||
def test_tracing_v2_agent_with_metadata() -> None:
|
||||
os.environ["LANGCHAIN_TRACING_V2"] = "true"
|
||||
llm = OpenAI(temperature=0)
|
||||
chat = ChatOpenAI(temperature=0)
|
||||
tools = load_tools(["llm-math", "serpapi"], llm=llm)
|
||||
agent = initialize_agent(
|
||||
tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
|
||||
)
|
||||
chat_agent = initialize_agent(
|
||||
tools, chat, agent=AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION, verbose=True
|
||||
)
|
||||
agent.run(questions[0], tags=["a-tag"], metadata={"a": "b", "c": "d"})
|
||||
chat_agent.run(questions[0], tags=["a-tag"], metadata={"a": "b", "c": "d"})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tracing_v2_async_agent_with_metadata() -> None:
|
||||
os.environ["LANGCHAIN_TRACING_V2"] = "true"
|
||||
llm = OpenAI(temperature=0, metadata={"f": "g", "h": "i"})
|
||||
chat = ChatOpenAI(temperature=0, metadata={"f": "g", "h": "i"})
|
||||
async_tools = load_tools(["llm-math", "serpapi"], llm=llm)
|
||||
agent = initialize_agent(
|
||||
async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
|
||||
)
|
||||
chat_agent = initialize_agent(
|
||||
async_tools,
|
||||
chat,
|
||||
agent=AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION,
|
||||
verbose=True,
|
||||
)
|
||||
await agent.arun(questions[0], tags=["a-tag"], metadata={"a": "b", "c": "d"})
|
||||
await chat_agent.arun(questions[0], tags=["a-tag"], metadata={"a": "b", "c": "d"})
|
||||
|
||||
|
||||
def test_trace_as_group() -> None:
|
||||
llm = OpenAI(temperature=0.9)
|
||||
prompt = PromptTemplate(
|
||||
|
Loading…
Reference in New Issue
Block a user