mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-17 18:23:59 +00:00
Allow astream_log to be used inside atrace_as_chain_group (#12558)
<!-- Thank you for contributing to LangChain! Replace this entire comment with: - **Description:** a description of the change, - **Issue:** the issue # it fixes (if applicable), - **Dependencies:** any dependencies required for this change, - **Tag maintainer:** for a quicker response, tag the relevant maintainer (see below), - **Twitter handle:** we announce bigger features on Twitter. If your PR gets announced, and you'd like a mention, we'll gladly shout you out! Please make sure your PR is passing linting and testing before submitting. Run `make format`, `make lint` and `make test` to check this locally. See contribution guidelines for more information on how to write/run tests, lint, etc: https://github.com/langchain-ai/langchain/blob/master/.github/CONTRIBUTING.md If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in `docs/extras` directory. If no one reviews your PR within a few days, please @-mention one of @baskaryan, @eyurtsev, @hwchase17. -->
This commit is contained in:
parent
8e88ba16a8
commit
7897483819
@ -1832,6 +1832,18 @@ class AsyncCallbackManagerForChainGroup(AsyncCallbackManager):
|
||||
self.parent_run_manager = parent_run_manager
|
||||
self.ended = False
|
||||
|
||||
def copy(self) -> AsyncCallbackManagerForChainGroup:
|
||||
return self.__class__(
|
||||
handlers=self.handlers,
|
||||
inheritable_handlers=self.inheritable_handlers,
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
inheritable_tags=self.inheritable_tags,
|
||||
metadata=self.metadata,
|
||||
inheritable_metadata=self.inheritable_metadata,
|
||||
parent_run_manager=self.parent_run_manager,
|
||||
)
|
||||
|
||||
async def on_chain_end(
|
||||
self, outputs: Union[Dict[str, Any], Any], **kwargs: Any
|
||||
) -> None:
|
||||
|
@ -157,12 +157,13 @@ class LogStreamCallbackHandler(BaseTracer):
|
||||
self.receive_stream = receive_stream
|
||||
self._key_map_by_run_id: Dict[UUID, str] = {}
|
||||
self._counter_map_by_name: Dict[str, int] = defaultdict(int)
|
||||
self.root_id: Optional[UUID] = None
|
||||
|
||||
def __aiter__(self) -> AsyncIterator[RunLogPatch]:
|
||||
return self.receive_stream.__aiter__()
|
||||
|
||||
def include_run(self, run: Run) -> bool:
|
||||
if run.parent_run_id is None:
|
||||
if run.id == self.root_id:
|
||||
return False
|
||||
|
||||
run_tags = run.tags or []
|
||||
@ -199,7 +200,8 @@ class LogStreamCallbackHandler(BaseTracer):
|
||||
|
||||
def _on_run_create(self, run: Run) -> None:
|
||||
"""Start a run."""
|
||||
if run.parent_run_id is None:
|
||||
if self.root_id is None:
|
||||
self.root_id = run.id
|
||||
self.send_stream.send_nowait(
|
||||
RunLogPatch(
|
||||
{
|
||||
@ -273,7 +275,7 @@ class LogStreamCallbackHandler(BaseTracer):
|
||||
)
|
||||
)
|
||||
finally:
|
||||
if run.parent_run_id is None:
|
||||
if run.id == self.root_id:
|
||||
self.send_stream.send_nowait(
|
||||
RunLogPatch(
|
||||
{
|
||||
|
@ -463,7 +463,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
config["callbacks"] = callbacks + [stream]
|
||||
elif isinstance(callbacks, BaseCallbackManager):
|
||||
callbacks = callbacks.copy()
|
||||
callbacks.inheritable_handlers.append(stream)
|
||||
callbacks.add_handler(stream, inherit=True)
|
||||
config["callbacks"] = callbacks
|
||||
else:
|
||||
raise ValueError(
|
||||
|
@ -19,7 +19,7 @@ from freezegun import freeze_time
|
||||
from pytest_mock import MockerFixture
|
||||
from syrupy import SnapshotAssertion
|
||||
|
||||
from langchain.callbacks.manager import Callbacks, collect_runs
|
||||
from langchain.callbacks.manager import Callbacks, atrace_as_chain_group, collect_runs
|
||||
from langchain.callbacks.tracers.base import BaseTracer
|
||||
from langchain.callbacks.tracers.log_stream import RunLog, RunLogPatch
|
||||
from langchain.callbacks.tracers.schemas import Run
|
||||
@ -1442,6 +1442,40 @@ async def test_prompt() -> None:
|
||||
},
|
||||
)
|
||||
|
||||
# nested inside trace_with_chain_group
|
||||
|
||||
async with atrace_as_chain_group("a_group") as manager:
|
||||
stream_log_nested = [
|
||||
part
|
||||
async for part in prompt.astream_log(
|
||||
{"question": "What is your name?"}, config={"callbacks": manager}
|
||||
)
|
||||
]
|
||||
|
||||
assert len(stream_log_nested[0].ops) == 1
|
||||
assert stream_log_nested[0].ops[0]["op"] == "replace"
|
||||
assert stream_log_nested[0].ops[0]["path"] == ""
|
||||
assert stream_log_nested[0].ops[0]["value"]["logs"] == {}
|
||||
assert stream_log_nested[0].ops[0]["value"]["final_output"] is None
|
||||
assert stream_log_nested[0].ops[0]["value"]["streamed_output"] == []
|
||||
assert isinstance(stream_log_nested[0].ops[0]["value"]["id"], str)
|
||||
|
||||
assert stream_log_nested[1:] == [
|
||||
RunLogPatch(
|
||||
{
|
||||
"op": "replace",
|
||||
"path": "/final_output",
|
||||
"value": ChatPromptValue(
|
||||
messages=[
|
||||
SystemMessage(content="You are a nice assistant."),
|
||||
HumanMessage(content="What is your name?"),
|
||||
]
|
||||
),
|
||||
}
|
||||
),
|
||||
RunLogPatch({"op": "add", "path": "/streamed_output/-", "value": expected}),
|
||||
]
|
||||
|
||||
|
||||
def test_prompt_template_params() -> None:
|
||||
prompt = ChatPromptTemplate.from_template(
|
||||
|
Loading…
Reference in New Issue
Block a user