Compare commits

...

3 Commits

Author SHA1 Message Date
William Fu-Hinthorn
12dd2c4526 lint 2023-08-06 15:18:40 -07:00
William Fu-Hinthorn
d9c292a2be update 2023-08-06 15:14:42 -07:00
William Fu-Hinthorn
47bb316024 Add example ID 2023-08-06 13:37:59 -07:00
7 changed files with 60 additions and 1 deletions

View File

@@ -227,6 +227,7 @@ def trace_as_chain_group(
cm = CallbackManager.configure(
inheritable_callbacks=cb,
inheritable_tags=tags,
example_id=example_id,
)
run_manager = cm.on_chain_start({"name": group_name}, {})
@@ -1273,6 +1274,7 @@ class CallbackManager(BaseCallbackManager):
local_tags: Optional[List[str]] = None,
inheritable_metadata: Optional[Dict[str, Any]] = None,
local_metadata: Optional[Dict[str, Any]] = None,
example_id: Optional[Union[str, UUID]] = None,
) -> CallbackManager:
"""Configure the callback manager.
@@ -1290,6 +1292,7 @@ class CallbackManager(BaseCallbackManager):
metadata. Defaults to None.
local_metadata (Optional[Dict[str, Any]], optional): The local metadata.
Defaults to None.
example_id (Optional[UUID], optional): The example ID. Defaults to None.
Returns:
CallbackManager: The configured callback manager.
@@ -1303,6 +1306,7 @@ class CallbackManager(BaseCallbackManager):
local_tags,
inheritable_metadata,
local_metadata,
example_id,
)
@@ -1565,6 +1569,7 @@ class AsyncCallbackManager(BaseCallbackManager):
local_tags: Optional[List[str]] = None,
inheritable_metadata: Optional[Dict[str, Any]] = None,
local_metadata: Optional[Dict[str, Any]] = None,
example_id: Optional[Union[str, UUID]] = None,
) -> AsyncCallbackManager:
"""Configure the async callback manager.
@@ -1582,6 +1587,7 @@ class AsyncCallbackManager(BaseCallbackManager):
metadata. Defaults to None.
local_metadata (Optional[Dict[str, Any]], optional): The local metadata.
Defaults to None.
example_id (Optional[UUID], optional): The ID of the example. Defaults to None.
Returns:
AsyncCallbackManager: The configured async callback manager.
@@ -1595,6 +1601,7 @@ class AsyncCallbackManager(BaseCallbackManager):
local_tags,
inheritable_metadata,
local_metadata,
example_id=example_id,
)
@@ -1627,6 +1634,7 @@ def _configure(
local_tags: Optional[List[str]] = None,
inheritable_metadata: Optional[Dict[str, Any]] = None,
local_metadata: Optional[Dict[str, Any]] = None,
example_id: Optional[Union[str, UUID]] = None,
) -> T:
"""Configure the callback manager.
@@ -1644,10 +1652,12 @@ def _configure(
metadata. Defaults to None.
local_metadata (Optional[Dict[str, Any]], optional): The local metadata.
Defaults to None.
example_id (Optional[UUID], optional): The example ID. Defaults to None.
Returns:
T: The configured callback manager.
"""
example_id = UUID(example_id) if isinstance(example_id, str) else example_id
callback_manager = callback_manager_cls(handlers=[])
if inheritable_callbacks or local_callbacks:
if isinstance(inheritable_callbacks, list) or inheritable_callbacks is None:
@@ -1744,10 +1754,16 @@ def _configure(
for handler in callback_manager.handlers
):
if tracer_v2:
if example_id:
# This can get ugly since we don't manage the un-setting
# of the example_id
tracer_v2.example_id = example_id
callback_manager.add_handler(tracer_v2, True)
else:
try:
handler = LangChainTracer(project_name=tracer_project)
handler = LangChainTracer(
project_name=tracer_project, example_id=example_id
)
callback_manager.add_handler(handler, True)
except Exception as e:
logger.warning(
@@ -1756,6 +1772,13 @@ def _configure(
" unset the LANGCHAIN_TRACING_V2 environment variables.",
e,
)
elif tracing_v2_enabled_ and example_id:
# This can get ugly since we don't manage the un-setting
# of the example_id
for handler in callback_manager.handlers:
if isinstance(handler, LangChainTracer):
handler.example_id = example_id
break
if open_ai is not None and not any(
isinstance(handler, OpenAICallbackHandler)
for handler in callback_manager.handlers

View File

@@ -6,6 +6,7 @@ import warnings
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from uuid import UUID
import yaml
from pydantic import Field, root_validator, validator
@@ -206,6 +207,7 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
include_run_info: bool = False,
example_id: Optional[UUID] = None,
) -> Dict[str, Any]:
"""Execute the chain.
@@ -241,6 +243,7 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
self.tags,
metadata,
self.metadata,
example_id=example_id,
)
new_arg_supported = inspect.signature(self._call).parameters.get("run_manager")
run_manager = callback_manager.on_chain_start(
@@ -273,6 +276,7 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
include_run_info: bool = False,
example_id: Optional[UUID] = None,
) -> Dict[str, Any]:
"""Asynchronously execute the chain.
@@ -294,6 +298,7 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
metadata: Optional metadata associated with the chain. Defaults to None
include_run_info: Whether to include run info in the response. Defaults
to False.
example_id: Optional UUID of the example being processed. Defaults to None.
Returns:
A dict of named outputs. Should contain all outputs specified in
@@ -308,6 +313,7 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
self.tags,
metadata,
self.metadata,
example_id=example_id,
)
new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager")
run_manager = await callback_manager.on_chain_start(

View File

@@ -158,6 +158,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
self.tags,
config.get("metadata"),
self.metadata,
example_id=config.get("example_id"),
)
(run_manager,) = callback_manager.on_chat_model_start(
dumpd(self), [messages], invocation_params=params, options=options

View File

@@ -336,6 +336,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
self.tags,
config.get("metadata"),
self.metadata,
example_id=config.get("example_id"),
)
(run_manager,) = callback_manager.on_llm_start(
dumpd(self), [prompt], invocation_params=params, options=options
@@ -383,6 +384,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
self.tags,
config.get("metadata"),
self.metadata,
example_id=kwargs.get("example_id"),
)
(run_manager,) = await callback_manager.on_llm_start(
dumpd(self), [prompt], invocation_params=params, options=options
@@ -542,6 +544,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
self.tags,
meta,
self.metadata,
example_id=kwargs.get("example_id"),
)
for callback, tag, meta in zip(callbacks, tags_list, metadata_list)
]
@@ -556,6 +559,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
self.tags,
cast(Dict[str, Any], metadata),
self.metadata,
example_id=kwargs.get("example_id"),
)
] * len(prompts)
@@ -691,6 +695,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
self.tags,
meta,
self.metadata,
example_id=kwargs.get("example_id"),
)
for callback, tag, meta in zip(callbacks, tags_list, metadata_list)
]
@@ -705,6 +710,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
self.tags,
cast(Dict[str, Any], metadata),
self.metadata,
example_id=kwargs.get("example_id"),
)
] * len(prompts)

View File

@@ -174,6 +174,7 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
local_tags=self.tags,
inheritable_metadata=metadata,
local_metadata=self.metadata,
example_id=kwargs.get("example_id"),
)
run_manager = callback_manager.on_retriever_start(
dumpd(self),
@@ -230,6 +231,7 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
local_tags=self.tags,
inheritable_metadata=metadata,
local_metadata=self.metadata,
example_id=kwargs.get("example_id"),
)
run_manager = await callback_manager.on_retriever_start(
dumpd(self),

View File

@@ -22,6 +22,7 @@ from typing import (
Union,
cast,
)
from uuid import UUID
from pydantic import Field
@@ -63,6 +64,11 @@ class RunnableConfig(TypedDict, total=False):
Tags are passed to all callbacks, metadata is passed to handle*Start callbacks.
"""
example_id: Optional[UUID]
"""
Example ID to associate with this call and sub-calls.
"""
Input = TypeVar("Input")
# Output type should implement __concat__, as eg str, list, dict do
@@ -175,6 +181,7 @@ class Runnable(Generic[Input, Output], ABC):
inheritable_callbacks=config.get("callbacks"),
inheritable_tags=config.get("tags"),
inheritable_metadata=config.get("metadata"),
example_id=config.get("example_id"),
)
run_manager = callback_manager.on_chain_start(
dumpd(self),
@@ -231,6 +238,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
example_id=config.get("example_id"),
)
# start the root run
run_manager = callback_manager.on_chain_start(
@@ -274,6 +282,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
example_id=config.get("example_id"),
)
# start the root run
run_manager = await callback_manager.on_chain_start(
@@ -323,6 +332,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
example_id=config.get("example_id"),
)
for config in configs
]
@@ -388,6 +398,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
example_id=config.get("example_id"),
)
for config in configs
]
@@ -505,6 +516,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
example_id=config.get("example_id"),
)
# start the root run
run_manager = callback_manager.on_chain_start(
@@ -544,6 +556,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
example_id=config.get("example_id"),
)
# start the root run
run_manager = await callback_manager.on_chain_start(
@@ -588,6 +601,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
example_id=config.get("example_id"),
)
for config in configs
]
@@ -644,6 +658,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
example_id=config.get("example_id"),
)
for config in configs
]
@@ -700,6 +715,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
example_id=config.get("example_id"),
)
# start the root run
run_manager = callback_manager.on_chain_start(
@@ -763,6 +779,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
example_id=config.get("example_id"),
)
# start the root run
run_manager = await callback_manager.on_chain_start(
@@ -852,6 +869,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
example_id=config.get("example_id"),
)
# start the root run
run_manager = callback_manager.on_chain_start(dumpd(self), {"input": input})
@@ -894,6 +912,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
example_id=config.get("example_id"),
)
# start the root run
run_manager = await callback_manager.on_chain_start(

View File

@@ -306,6 +306,7 @@ class BaseTool(BaseModel, Runnable[Union[str, Dict], Any], metaclass=ToolMetacla
self.tags,
metadata,
self.metadata,
example_id=kwargs.get("example_id"),
)
# TODO: maybe also pass through run_manager is _run supports kwargs
new_arg_supported = signature(self._run).parameters.get("run_manager")
@@ -379,6 +380,7 @@ class BaseTool(BaseModel, Runnable[Union[str, Dict], Any], metaclass=ToolMetacla
self.tags,
metadata,
self.metadata,
example_id=kwargs.get("example_id"),
)
new_arg_supported = signature(self._arun).parameters.get("run_manager")
run_manager = await callback_manager.on_tool_start(