commit fcb5aba9f0 (parent a1ade48e8f)

Add Runnable.astream_log() (#10374)

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
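
For orientation, a minimal usage sketch of the API this commit introduces. The
chain construction and variable names are illustrative assumptions; only
astream_log() and its include/exclude filter arguments come from the diff below:

    # Hypothetical consumer of the new astream_log() API (not part of this
    # commit). Assumes `chain` is any Runnable, e.g. a prompt | llm sequence.
    from langchain.schema.runnable import Runnable

    async def show_patches(chain: Runnable, question: str) -> None:
        async for patch in chain.astream_log(
            {"question": question},
            include_types=["llm"],  # only surface sub-runs of LLMs
        ):
            # Each RunLogPatch carries jsonpatch ops describing how the
            # run state changed since the previous patch.
            print(patch.ops)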
libs/langchain/langchain/callbacks/base.py
@@ -1,7 +1,7 @@
 """Base callback handler that can be used to handle callbacks in langchain."""
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, TypeVar, Union
 from uuid import UUID
 
 from tenacity import RetryCallState
@@ -502,6 +502,9 @@ class AsyncCallbackHandler(BaseCallbackHandler):
         """Run on retriever error."""
 
 
+T = TypeVar("T", bound="BaseCallbackManager")
+
+
 class BaseCallbackManager(CallbackManagerMixin):
     """Base callback manager that handles callbacks from LangChain."""
 
@@ -527,6 +530,18 @@ class BaseCallbackManager(CallbackManagerMixin):
         self.metadata = metadata or {}
         self.inheritable_metadata = inheritable_metadata or {}
 
+    def copy(self: T) -> T:
+        """Copy the callback manager."""
+        return self.__class__(
+            handlers=self.handlers,
+            inheritable_handlers=self.inheritable_handlers,
+            parent_run_id=self.parent_run_id,
+            tags=self.tags,
+            inheritable_tags=self.inheritable_tags,
+            metadata=self.metadata,
+            inheritable_metadata=self.inheritable_metadata,
+        )
+
     @property
     def is_async(self) -> bool:
         """Whether the callback manager is async."""
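
The `self: T` annotation on the new copy() method, together with the bound
TypeVar added above, lets subclasses of BaseCallbackManager get back their own
type from the inherited method. A minimal sketch of the pattern, with
illustrative class names:

    from typing import TypeVar

    T = TypeVar("T", bound="Base")

    class Base:
        def copy(self: T) -> T:
            # Rebuild via self.__class__ so subclasses survive copying.
            return self.__class__()

    class Child(Base):
        pass

    c: Child = Child().copy()  # type-checks: copy() returns Child, not Base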
libs/langchain/langchain/callbacks/tracers/base.py
@@ -58,6 +58,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
         else:
             logger.debug(f"Parent run with UUID {run.parent_run_id} not found.")
         self.run_map[str(run.id)] = run
+        self._on_run_create(run)
 
     def _end_trace(self, run: Run) -> None:
         """End a trace for a run."""
@@ -74,6 +75,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
         ):
             parent_run.child_execution_order = run.child_execution_order
         self.run_map.pop(str(run.id))
+        self._on_run_update(run)
 
     def _get_execution_order(self, parent_run_id: Optional[str] = None) -> int:
         """Get the execution order for a run."""
@@ -101,7 +103,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
         parent_run_id: Optional[UUID] = None,
         metadata: Optional[Dict[str, Any]] = None,
         **kwargs: Any,
-    ) -> None:
+    ) -> Run:
         """Start a trace for an LLM run."""
         parent_run_id_ = str(parent_run_id) if parent_run_id else None
         execution_order = self._get_execution_order(parent_run_id_)
@@ -123,6 +125,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
         )
         self._start_trace(llm_run)
         self._on_llm_start(llm_run)
+        return llm_run
 
     def on_llm_new_token(
         self,
@@ -132,7 +135,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
         run_id: UUID,
         parent_run_id: Optional[UUID] = None,
         **kwargs: Any,
-    ) -> None:
+    ) -> Run:
         """Run on new LLM token. Only available when streaming is enabled."""
         if not run_id:
             raise TracerException("No run_id provided for on_llm_new_token callback.")
@@ -151,6 +154,8 @@ class BaseTracer(BaseCallbackHandler, ABC):
                 "kwargs": event_kwargs,
             },
         )
+        self._on_llm_new_token(llm_run, token, chunk)
+        return llm_run
 
     def on_retry(
         self,
@@ -158,7 +163,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
         *,
         run_id: UUID,
         **kwargs: Any,
-    ) -> None:
+    ) -> Run:
         if not run_id:
             raise TracerException("No run_id provided for on_retry callback.")
         run_id_ = str(run_id)
@@ -186,8 +191,9 @@ class BaseTracer(BaseCallbackHandler, ABC):
                 "kwargs": retry_d,
             },
         )
+        return llm_run
 
-    def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any) -> None:
+    def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any) -> Run:
         """End a trace for an LLM run."""
         if not run_id:
             raise TracerException("No run_id provided for on_llm_end callback.")
@@ -208,6 +214,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
         llm_run.events.append({"name": "end", "time": llm_run.end_time})
         self._end_trace(llm_run)
         self._on_llm_end(llm_run)
+        return llm_run
 
     def on_llm_error(
         self,
@@ -215,7 +222,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
         *,
         run_id: UUID,
         **kwargs: Any,
-    ) -> None:
+    ) -> Run:
         """Handle an error for an LLM run."""
         if not run_id:
             raise TracerException("No run_id provided for on_llm_error callback.")
@@ -229,6 +236,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
         llm_run.events.append({"name": "error", "time": llm_run.end_time})
         self._end_trace(llm_run)
         self._on_chain_error(llm_run)
+        return llm_run
 
     def on_chain_start(
         self,
@@ -242,7 +250,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
         run_type: Optional[str] = None,
         name: Optional[str] = None,
         **kwargs: Any,
-    ) -> None:
+    ) -> Run:
         """Start a trace for a chain run."""
         parent_run_id_ = str(parent_run_id) if parent_run_id else None
         execution_order = self._get_execution_order(parent_run_id_)
@@ -266,6 +274,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
         )
         self._start_trace(chain_run)
         self._on_chain_start(chain_run)
+        return chain_run
 
     def on_chain_end(
         self,
@@ -274,7 +283,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
         run_id: UUID,
         inputs: Optional[Dict[str, Any]] = None,
         **kwargs: Any,
-    ) -> None:
+    ) -> Run:
         """End a trace for a chain run."""
         if not run_id:
             raise TracerException("No run_id provided for on_chain_end callback.")
@@ -291,6 +300,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
             chain_run.inputs = inputs if isinstance(inputs, dict) else {"input": inputs}
         self._end_trace(chain_run)
         self._on_chain_end(chain_run)
+        return chain_run
 
     def on_chain_error(
         self,
@@ -299,7 +309,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
         inputs: Optional[Dict[str, Any]] = None,
         run_id: UUID,
         **kwargs: Any,
-    ) -> None:
+    ) -> Run:
         """Handle an error for a chain run."""
         if not run_id:
             raise TracerException("No run_id provided for on_chain_error callback.")
@@ -314,6 +324,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
             chain_run.inputs = inputs if isinstance(inputs, dict) else {"input": inputs}
         self._end_trace(chain_run)
         self._on_chain_error(chain_run)
+        return chain_run
 
     def on_tool_start(
         self,
@@ -325,7 +336,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
         parent_run_id: Optional[UUID] = None,
         metadata: Optional[Dict[str, Any]] = None,
         **kwargs: Any,
-    ) -> None:
+    ) -> Run:
         """Start a trace for a tool run."""
         parent_run_id_ = str(parent_run_id) if parent_run_id else None
         execution_order = self._get_execution_order(parent_run_id_)
@@ -348,8 +359,9 @@ class BaseTracer(BaseCallbackHandler, ABC):
         )
         self._start_trace(tool_run)
         self._on_tool_start(tool_run)
+        return tool_run
 
-    def on_tool_end(self, output: str, *, run_id: UUID, **kwargs: Any) -> None:
+    def on_tool_end(self, output: str, *, run_id: UUID, **kwargs: Any) -> Run:
         """End a trace for a tool run."""
         if not run_id:
             raise TracerException("No run_id provided for on_tool_end callback.")
@@ -362,6 +374,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
         tool_run.events.append({"name": "end", "time": tool_run.end_time})
         self._end_trace(tool_run)
         self._on_tool_end(tool_run)
+        return tool_run
 
     def on_tool_error(
         self,
@@ -369,7 +382,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
         *,
         run_id: UUID,
         **kwargs: Any,
-    ) -> None:
+    ) -> Run:
         """Handle an error for a tool run."""
         if not run_id:
             raise TracerException("No run_id provided for on_tool_error callback.")
@@ -382,6 +395,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
         tool_run.events.append({"name": "error", "time": tool_run.end_time})
         self._end_trace(tool_run)
         self._on_tool_error(tool_run)
+        return tool_run
 
     def on_retriever_start(
         self,
@@ -393,7 +407,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
         tags: Optional[List[str]] = None,
         metadata: Optional[Dict[str, Any]] = None,
         **kwargs: Any,
-    ) -> None:
+    ) -> Run:
         """Run when Retriever starts running."""
         parent_run_id_ = str(parent_run_id) if parent_run_id else None
         execution_order = self._get_execution_order(parent_run_id_)
@@ -417,6 +431,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
         )
         self._start_trace(retrieval_run)
         self._on_retriever_start(retrieval_run)
+        return retrieval_run
 
     def on_retriever_error(
         self,
@@ -424,7 +439,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
         *,
         run_id: UUID,
         **kwargs: Any,
-    ) -> None:
+    ) -> Run:
         """Run when Retriever errors."""
         if not run_id:
             raise TracerException("No run_id provided for on_retriever_error callback.")
@@ -437,10 +452,11 @@ class BaseTracer(BaseCallbackHandler, ABC):
         retrieval_run.events.append({"name": "error", "time": retrieval_run.end_time})
         self._end_trace(retrieval_run)
         self._on_retriever_error(retrieval_run)
+        return retrieval_run
 
     def on_retriever_end(
         self, documents: Sequence[Document], *, run_id: UUID, **kwargs: Any
-    ) -> None:
+    ) -> Run:
         """Run when Retriever ends running."""
         if not run_id:
             raise TracerException("No run_id provided for on_retriever_end callback.")
@@ -452,6 +468,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
         retrieval_run.events.append({"name": "end", "time": retrieval_run.end_time})
         self._end_trace(retrieval_run)
         self._on_retriever_end(retrieval_run)
+        return retrieval_run
 
     def __deepcopy__(self, memo: dict) -> BaseTracer:
         """Deepcopy the tracer."""
@@ -461,9 +478,23 @@ class BaseTracer(BaseCallbackHandler, ABC):
         """Copy the tracer."""
         return self
 
+    def _on_run_create(self, run: Run) -> None:
+        """Process a run upon creation."""
+
+    def _on_run_update(self, run: Run) -> None:
+        """Process a run upon update."""
+
     def _on_llm_start(self, run: Run) -> None:
         """Process the LLM Run upon start."""
 
+    def _on_llm_new_token(
+        self,
+        run: Run,
+        token: str,
+        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]],
+    ) -> None:
+        """Process new LLM token."""
+
     def _on_llm_end(self, run: Run) -> None:
         """Process the LLM Run."""
 
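
The new _on_run_create/_on_run_update hooks (plus _on_llm_new_token) give
BaseTracer subclasses a single place to observe every run in a tree, which is
what LogStreamCallbackHandler below builds on. A minimal sketch of a custom
subclass using them (PrintTracer is an illustrative name, not part of this
commit):

    from langchain.callbacks.tracers.base import BaseTracer
    from langchain.callbacks.tracers.schemas import Run

    class PrintTracer(BaseTracer):
        """Illustrative tracer: print each run as it starts and finishes."""

        def _persist_run(self, run: Run) -> None:
            pass  # abstract in BaseTracer; persistence not needed here

        def _on_run_create(self, run: Run) -> None:
            print(f"started {run.run_type} run {run.id}")

        def _on_run_update(self, run: Run) -> None:
            print(f"finished {run.run_type} run {run.id}")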
libs/langchain/langchain/callbacks/tracers/log_stream.py  (new file, 289 lines)
@@ -0,0 +1,289 @@
+from __future__ import annotations
+
+import math
+import threading
+from typing import (
+    Any,
+    AsyncIterator,
+    Dict,
+    List,
+    Optional,
+    Sequence,
+    TypedDict,
+    Union,
+)
+from uuid import UUID
+
+import jsonpatch
+from anyio import create_memory_object_stream
+
+from langchain.callbacks.tracers.base import BaseTracer
+from langchain.callbacks.tracers.schemas import Run
+from langchain.schema.output import ChatGenerationChunk, GenerationChunk
+
+
+class LogEntry(TypedDict):
+    id: str
+    """ID of the sub-run."""
+    name: str
+    """Name of the object being run."""
+    type: str
+    """Type of the object being run, eg. prompt, chain, llm, etc."""
+    tags: List[str]
+    """List of tags for the run."""
+    metadata: Dict[str, Any]
+    """Key-value pairs of metadata for the run."""
+    start_time: str
+    """ISO-8601 timestamp of when the run started."""
+
+    streamed_output_str: List[str]
+    """List of LLM tokens streamed by this run, if applicable."""
+    final_output: Optional[Any]
+    """Final output of this run.
+    Only available after the run has finished successfully."""
+    end_time: Optional[str]
+    """ISO-8601 timestamp of when the run ended.
+    Only available after the run has finished."""
+
+
+class RunState(TypedDict):
+    id: str
+    """ID of the run."""
+    streamed_output: List[Any]
+    """List of output chunks streamed by Runnable.stream()"""
+    final_output: Optional[Any]
+    """Final output of the run, usually the result of aggregating streamed_output.
+    Only available after the run has finished successfully."""
+
+    logs: list[LogEntry]
+    """List of sub-runs contained in this run, if any, in the order they were started.
+    If filters were supplied, this list will contain only the runs that matched the
+    filters."""
+
+
+class RunLogPatch:
+    ops: List[Dict[str, Any]]
+    """List of jsonpatch operations, which describe how to create the run state
+    from an empty dict. This is the minimal representation of the log, designed to
+    be serialized as JSON and sent over the wire to reconstruct the log on the other
+    side. Reconstruction of the state can be done with any jsonpatch-compliant library,
+    see https://jsonpatch.com for more information."""
+
+    def __init__(self, *ops: Dict[str, Any]) -> None:
+        self.ops = list(ops)
+
+    def __add__(self, other: Union[RunLogPatch, Any]) -> RunLogPatch:
+        if type(other) == RunLogPatch:
+            ops = self.ops + other.ops
+            state = jsonpatch.apply_patch(None, ops)
+            return RunLog(*ops, state=state)
+
+        raise TypeError(
+            f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'"
+        )
+
+    def __repr__(self) -> str:
+        from pprint import pformat
+
+        return f"RunLogPatch(ops={pformat(self.ops)})"
+
+    def __eq__(self, other: object) -> bool:
+        return isinstance(other, RunLogPatch) and self.ops == other.ops
+
+
+class RunLog(RunLogPatch):
+    state: RunState
+    """Current state of the log, obtained from applying all ops in sequence."""
+
+    def __init__(self, *ops: Dict[str, Any], state: RunState) -> None:
+        super().__init__(*ops)
+        self.state = state
+
+    def __add__(self, other: Union[RunLogPatch, Any]) -> RunLogPatch:
+        if type(other) == RunLogPatch:
+            ops = self.ops + other.ops
+            state = jsonpatch.apply_patch(self.state, other.ops)
+            return RunLog(*ops, state=state)
+
+        raise TypeError(
+            f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'"
+        )
+
+    def __repr__(self) -> str:
+        from pprint import pformat
+
+        return f"RunLog(state={pformat(self.state)})"
+
+
+class LogStreamCallbackHandler(BaseTracer):
+    def __init__(
+        self,
+        *,
+        auto_close: bool = True,
+        include_names: Optional[Sequence[str]] = None,
+        include_types: Optional[Sequence[str]] = None,
+        include_tags: Optional[Sequence[str]] = None,
+        exclude_names: Optional[Sequence[str]] = None,
+        exclude_types: Optional[Sequence[str]] = None,
+        exclude_tags: Optional[Sequence[str]] = None,
+    ) -> None:
+        super().__init__()
+
+        self.auto_close = auto_close
+        self.include_names = include_names
+        self.include_types = include_types
+        self.include_tags = include_tags
+        self.exclude_names = exclude_names
+        self.exclude_types = exclude_types
+        self.exclude_tags = exclude_tags
+
+        send_stream, receive_stream = create_memory_object_stream(
+            math.inf, item_type=RunLogPatch
+        )
+        self.lock = threading.Lock()
+        self.send_stream = send_stream
+        self.receive_stream = receive_stream
+        self._index_map: Dict[UUID, int] = {}
+
+    def __aiter__(self) -> AsyncIterator[RunLogPatch]:
+        return self.receive_stream.__aiter__()
+
+    def include_run(self, run: Run) -> bool:
+        if run.parent_run_id is None:
+            return False
+
+        run_tags = run.tags or []
+
+        if (
+            self.include_names is None
+            and self.include_types is None
+            and self.include_tags is None
+        ):
+            include = True
+        else:
+            include = False
+
+        if self.include_names is not None:
+            include = include or run.name in self.include_names
+        if self.include_types is not None:
+            include = include or run.run_type in self.include_types
+        if self.include_tags is not None:
+            include = include or any(tag in self.include_tags for tag in run_tags)
+
+        if self.exclude_names is not None:
+            include = include and run.name not in self.exclude_names
+        if self.exclude_types is not None:
+            include = include and run.run_type not in self.exclude_types
+        if self.exclude_tags is not None:
+            include = include and all(tag not in self.exclude_tags for tag in run_tags)
+
+        return include
+
+    def _persist_run(self, run: Run) -> None:
+        # This is a legacy method only called once for an entire run tree
+        # therefore not useful here
+        pass
+
+    def _on_run_create(self, run: Run) -> None:
+        """Start a run."""
+        if run.parent_run_id is None:
+            self.send_stream.send_nowait(
+                RunLogPatch(
+                    {
+                        "op": "replace",
+                        "path": "",
+                        "value": RunState(
+                            id=run.id,
+                            streamed_output=[],
+                            final_output=None,
+                            logs=[],
+                        ),
+                    }
+                )
+            )
+
+        if not self.include_run(run):
+            return
+
+        # Determine previous index, increment by 1
+        with self.lock:
+            self._index_map[run.id] = max(self._index_map.values(), default=-1) + 1
+
+        # Add the run to the stream
+        self.send_stream.send_nowait(
+            RunLogPatch(
+                {
+                    "op": "add",
+                    "path": f"/logs/{self._index_map[run.id]}",
+                    "value": LogEntry(
+                        id=str(run.id),
+                        name=run.name,
+                        type=run.run_type,
+                        tags=run.tags or [],
+                        metadata=run.extra.get("metadata", {}),
+                        start_time=run.start_time.isoformat(timespec="milliseconds"),
+                        streamed_output_str=[],
+                        final_output=None,
+                        end_time=None,
+                    ),
+                }
+            )
+        )
+
+    def _on_run_update(self, run: Run) -> None:
+        """Finish a run."""
+        try:
+            index = self._index_map.get(run.id)
+
+            if index is None:
+                return
+
+            self.send_stream.send_nowait(
+                RunLogPatch(
+                    {
+                        "op": "add",
+                        "path": f"/logs/{index}/final_output",
+                        "value": run.outputs,
+                    },
+                    {
+                        "op": "add",
+                        "path": f"/logs/{index}/end_time",
+                        "value": run.end_time.isoformat(timespec="milliseconds"),
+                    },
+                )
+            )
+        finally:
+            if run.parent_run_id is None:
+                self.send_stream.send_nowait(
+                    RunLogPatch(
+                        {
+                            "op": "replace",
+                            "path": "/final_output",
+                            "value": run.outputs,
+                        }
+                    )
+                )
+
+                if self.auto_close:
+                    self.send_stream.close()
+
+    def _on_llm_new_token(
+        self,
+        run: Run,
+        token: str,
+        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]],
+    ) -> None:
+        """Process new LLM token."""
+        index = self._index_map.get(run.id)
+
+        if index is None:
+            return
+
+        self.send_stream.send_nowait(
+            RunLogPatch(
+                {
+                    "op": "add",
+                    "path": f"/logs/{index}/streamed_output_str/-",
+                    "value": token,
+                }
+            )
+        )
libs/langchain/langchain/schema/runnable/base.py
@@ -34,6 +34,8 @@ if TYPE_CHECKING:
     )


 from langchain.callbacks.base import BaseCallbackManager
+from langchain.callbacks.tracers.log_stream import LogStreamCallbackHandler, RunLogPatch
 from langchain.load.dump import dumpd
 from langchain.load.serializable import Serializable
 from langchain.pydantic_v1 import Field
@@ -190,6 +192,89 @@ class Runnable(Generic[Input, Output], ABC):
         """
         yield await self.ainvoke(input, config, **kwargs)
 
+    async def astream_log(
+        self,
+        input: Any,
+        config: Optional[RunnableConfig] = None,
+        *,
+        include_names: Optional[Sequence[str]] = None,
+        include_types: Optional[Sequence[str]] = None,
+        include_tags: Optional[Sequence[str]] = None,
+        exclude_names: Optional[Sequence[str]] = None,
+        exclude_types: Optional[Sequence[str]] = None,
+        exclude_tags: Optional[Sequence[str]] = None,
+        **kwargs: Optional[Any],
+    ) -> AsyncIterator[RunLogPatch]:
+        """
+        Stream all output from a runnable, as reported to the callback system.
+        This includes all inner runs of LLMs, Retrievers, Tools, etc.
+
+        Output is streamed as Log objects, which include a list of
+        jsonpatch ops that describe how the state of the run has changed in each
+        step, and the final state of the run.
+
+        The jsonpatch ops can be applied in order to construct state.
+        """
+
+        # Create a stream handler that will emit Log objects
+        stream = LogStreamCallbackHandler(
+            auto_close=False,
+            include_names=include_names,
+            include_types=include_types,
+            include_tags=include_tags,
+            exclude_names=exclude_names,
+            exclude_types=exclude_types,
+            exclude_tags=exclude_tags,
+        )
+
+        # Assign the stream handler to the config
+        config = config or {}
+        callbacks = config.get("callbacks")
+        if callbacks is None:
+            config["callbacks"] = [stream]
+        elif isinstance(callbacks, list):
+            config["callbacks"] = callbacks + [stream]
+        elif isinstance(callbacks, BaseCallbackManager):
+            callbacks = callbacks.copy()
+            callbacks.inheritable_handlers.append(stream)
+            config["callbacks"] = callbacks
+        else:
+            raise ValueError(
+                f"Unexpected type for callbacks: {callbacks}."
+                "Expected None, list or AsyncCallbackManager."
+            )
+
+        # Call the runnable in streaming mode,
+        # add each chunk to the output stream
+        async def consume_astream() -> None:
+            try:
+                async for chunk in self.astream(input, config, **kwargs):
+                    await stream.send_stream.send(
+                        RunLogPatch(
+                            {
+                                "op": "add",
+                                "path": "/streamed_output/-",
+                                "value": chunk,
+                            }
+                        )
+                    )
+            finally:
+                await stream.send_stream.aclose()
+
+        # Start the runnable in a task, so we can start consuming output
+        task = asyncio.create_task(consume_astream())
+
+        try:
+            # Yield each chunk from the output stream
+            async for log in stream:
+                yield log
+        finally:
+            # Wait for the runnable to finish, if not cancelled (eg. by break)
+            try:
+                await task
+            except asyncio.CancelledError:
+                pass
+
     def transform(
         self,
         input: Iterator[Input],
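
Per the astream_log() docstring, the ops can be applied in order to rebuild
state, e.g. on the receiving side of a wire. A hedged sketch using the
jsonpatch dependency directly (runnable and input_value are placeholders):

    import jsonpatch

    async def rebuild_state(runnable, input_value):
        state = None
        async for patch in runnable.astream_log(input_value):
            # Apply each batch of ops to the accumulated document.
            state = jsonpatch.apply_patch(state, patch.ops)
        return state  # dict with id, streamed_output, final_output, logs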
libs/langchain/poetry.lock  (generated, 16 lines changed)
@@ -3610,6 +3610,20 @@ files = [
 [package.dependencies]
 attrs = ">=19.2.0"
 
+[[package]]
+name = "jsonpatch"
+version = "1.33"
+description = "Apply JSON-Patches (RFC 6902)"
+optional = false
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*"
+files = [
+    {file = "jsonpatch-1.33-py2.py3-none-any.whl", hash = "sha256:0ae28c0cd062bbd8b8ecc26d7d164fbbea9652a1a3693f3b956c1eae5145dade"},
+    {file = "jsonpatch-1.33.tar.gz", hash = "sha256:9fcd4009c41e6d12348b4a0ff2563ba56a2923a7dfee731d004e212e1ee5030c"},
+]
+
+[package.dependencies]
+jsonpointer = ">=1.9"
+
 [[package]]
 name = "jsonpointer"
 version = "2.4"
@@ -10608,4 +10622,4 @@ text-helpers = ["chardet"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0"
-content-hash = "11ce1c967a78f79a922b9bbbc1c00541703185e28c63b7a0a02aa5c562c36ee3"
+content-hash = "3a3749b3d63be94ef11de23ec7ad40cc20cca78fa7352c5ed7d537988ce90a85"
libs/langchain/pyproject.toml
@@ -129,6 +129,8 @@ markdownify = {version = "^0.11.6", optional = true}
 assemblyai = {version = "^0.17.0", optional = true}
 dashvector = {version = "^1.0.1", optional = true}
 sqlite-vss = {version = "^0.1.2", optional = true}
+anyio = "<4.0"
+jsonpatch = "^1.33"
 timescale-vector = {version = "^0.0.1", optional = true}
libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py
@@ -1,5 +1,5 @@
 from operator import itemgetter
-from typing import Any, Dict, List, Optional, Sequence, Union
+from typing import Any, Dict, List, Optional, Sequence, Union, cast
 from uuid import UUID
 
 import pytest
@@ -9,6 +9,7 @@ from syrupy import SnapshotAssertion
 
 from langchain.callbacks.manager import Callbacks, collect_runs
 from langchain.callbacks.tracers.base import BaseTracer
+from langchain.callbacks.tracers.log_stream import RunLog, RunLogPatch
 from langchain.callbacks.tracers.schemas import Run
 from langchain.callbacks.tracers.stdout import ConsoleCallbackHandler
 from langchain.chat_models.fake import FakeListChatModel
@@ -368,6 +369,62 @@ async def test_prompt() -> None:
         part async for part in prompt.astream({"question": "What is your name?"})
     ] == [expected]
 
+    stream_log = [
+        part async for part in prompt.astream_log({"question": "What is your name?"})
+    ]
+
+    assert len(stream_log[0].ops) == 1
+    assert stream_log[0].ops[0]["op"] == "replace"
+    assert stream_log[0].ops[0]["path"] == ""
+    assert stream_log[0].ops[0]["value"]["logs"] == []
+    assert stream_log[0].ops[0]["value"]["final_output"] is None
+    assert stream_log[0].ops[0]["value"]["streamed_output"] == []
+    assert type(stream_log[0].ops[0]["value"]["id"]) == UUID
+
+    assert stream_log[1:] == [
+        RunLogPatch(
+            {
+                "op": "replace",
+                "path": "/final_output",
+                "value": {
+                    "id": ["langchain", "prompts", "chat", "ChatPromptValue"],
+                    "kwargs": {
+                        "messages": [
+                            {
+                                "id": [
+                                    "langchain",
+                                    "schema",
+                                    "messages",
+                                    "SystemMessage",
+                                ],
+                                "kwargs": {"content": "You are a nice " "assistant."},
+                                "lc": 1,
+                                "type": "constructor",
+                            },
+                            {
+                                "id": [
+                                    "langchain",
+                                    "schema",
+                                    "messages",
+                                    "HumanMessage",
+                                ],
+                                "kwargs": {
+                                    "additional_kwargs": {},
+                                    "content": "What is your " "name?",
+                                },
+                                "lc": 1,
+                                "type": "constructor",
+                            },
+                        ]
+                    },
+                    "lc": 1,
+                    "type": "constructor",
+                },
+            }
+        ),
+        RunLogPatch({"op": "add", "path": "/streamed_output/-", "value": expected}),
+    ]
+
 
 def test_prompt_template_params() -> None:
     prompt = ChatPromptTemplate.from_template(
@@ -560,7 +617,7 @@ async def test_prompt_with_llm(
     mocker.stop(prompt_spy)
     mocker.stop(llm_spy)
 
-    # Test stream#
+    # Test stream
    prompt_spy = mocker.spy(prompt.__class__, "ainvoke")
    llm_spy = mocker.spy(llm.__class__, "astream")
    tracer = FakeTracer()
@@ -578,6 +635,136 @@ async def test_prompt_with_llm(
         ]
     )
 
+    prompt_spy.reset_mock()
+    llm_spy.reset_mock()
+    stream_log = [
+        part async for part in chain.astream_log({"question": "What is your name?"})
+    ]
+
+    # remove ids from logs
+    for part in stream_log:
+        for op in part.ops:
+            if (
+                isinstance(op["value"], dict)
+                and "id" in op["value"]
+                and not isinstance(op["value"]["id"], list)  # serialized lc id
+            ):
+                del op["value"]["id"]
+
+    assert stream_log == [
+        RunLogPatch(
+            {
+                "op": "replace",
+                "path": "",
+                "value": {
+                    "logs": [],
+                    "final_output": None,
+                    "streamed_output": [],
+                },
+            }
+        ),
+        RunLogPatch(
+            {
+                "op": "add",
+                "path": "/logs/0",
+                "value": {
+                    "end_time": None,
+                    "final_output": None,
+                    "metadata": {},
+                    "name": "ChatPromptTemplate",
+                    "start_time": "2023-01-01T00:00:00.000",
+                    "streamed_output_str": [],
+                    "tags": ["seq:step:1"],
+                    "type": "prompt",
+                },
+            }
+        ),
+        RunLogPatch(
+            {
+                "op": "add",
+                "path": "/logs/0/final_output",
+                "value": {
+                    "id": ["langchain", "prompts", "chat", "ChatPromptValue"],
+                    "kwargs": {
+                        "messages": [
+                            {
+                                "id": [
+                                    "langchain",
+                                    "schema",
+                                    "messages",
+                                    "SystemMessage",
+                                ],
+                                "kwargs": {
+                                    "additional_kwargs": {},
+                                    "content": "You are a nice " "assistant.",
+                                },
+                                "lc": 1,
+                                "type": "constructor",
+                            },
+                            {
+                                "id": [
+                                    "langchain",
+                                    "schema",
+                                    "messages",
+                                    "HumanMessage",
+                                ],
+                                "kwargs": {
+                                    "additional_kwargs": {},
+                                    "content": "What is your " "name?",
+                                },
+                                "lc": 1,
+                                "type": "constructor",
+                            },
+                        ]
+                    },
+                    "lc": 1,
+                    "type": "constructor",
+                },
+            },
+            {
+                "op": "add",
+                "path": "/logs/0/end_time",
+                "value": "2023-01-01T00:00:00.000",
+            },
+        ),
+        RunLogPatch(
+            {
+                "op": "add",
+                "path": "/logs/1",
+                "value": {
+                    "end_time": None,
+                    "final_output": None,
+                    "metadata": {},
+                    "name": "FakeListLLM",
+                    "start_time": "2023-01-01T00:00:00.000",
+                    "streamed_output_str": [],
+                    "tags": ["seq:step:2"],
+                    "type": "llm",
+                },
+            }
+        ),
+        RunLogPatch(
+            {
+                "op": "add",
+                "path": "/logs/1/final_output",
+                "value": {
+                    "generations": [[{"generation_info": None, "text": "foo"}]],
+                    "llm_output": None,
+                    "run": None,
+                },
+            },
+            {
+                "op": "add",
+                "path": "/logs/1/end_time",
+                "value": "2023-01-01T00:00:00.000",
+            },
+        ),
+        RunLogPatch({"op": "add", "path": "/streamed_output/-", "value": "foo"}),
+        RunLogPatch(
+            {"op": "replace", "path": "/final_output", "value": {"output": "foo"}}
+        ),
+    ]
+
 
 @pytest.mark.asyncio
 @freeze_time("2023-01-01")
@@ -1213,6 +1400,74 @@ async def test_map_astream() -> None:
         {"question": "What is your name?"}
     )
 
+    # Test astream_log state accumulation
+
+    final_state = None
+    streamed_ops = []
+    async for chunk in chain.astream_log({"question": "What is your name?"}):
+        streamed_ops.extend(chunk.ops)
+        if final_state is None:
+            final_state = chunk
+        else:
+            final_state += chunk
+    final_state = cast(RunLog, final_state)
+
+    assert final_state.state["final_output"] == final_value
+    assert len(final_state.state["streamed_output"]) == len(streamed_chunks)
+    assert isinstance(final_state.state["id"], UUID)
+    assert len(final_state.ops) == len(streamed_ops)
+    assert len(final_state.state["logs"]) == 5
+    assert final_state.state["logs"][0]["name"] == "ChatPromptTemplate"
+    assert final_state.state["logs"][0]["final_output"] == dumpd(
+        prompt.invoke({"question": "What is your name?"})
+    )
+    assert final_state.state["logs"][1]["name"] == "RunnableMap"
+    assert sorted(log["name"] for log in final_state.state["logs"][2:]) == [
+        "FakeListChatModel",
+        "FakeStreamingListLLM",
+        "RunnablePassthrough",
+    ]
+
+    # Test astream_log with include filters
+    final_state = None
+    async for chunk in chain.astream_log(
+        {"question": "What is your name?"}, include_names=["FakeListChatModel"]
+    ):
+        if final_state is None:
+            final_state = chunk
+        else:
+            final_state += chunk
+    final_state = cast(RunLog, final_state)
+
+    assert final_state.state["final_output"] == final_value
+    assert len(final_state.state["streamed_output"]) == len(streamed_chunks)
+    assert len(final_state.state["logs"]) == 1
+    assert final_state.state["logs"][0]["name"] == "FakeListChatModel"
+
+    # Test astream_log with exclude filters
+    final_state = None
+    async for chunk in chain.astream_log(
+        {"question": "What is your name?"}, exclude_names=["FakeListChatModel"]
+    ):
+        if final_state is None:
+            final_state = chunk
+        else:
+            final_state += chunk
+    final_state = cast(RunLog, final_state)
+
+    assert final_state.state["final_output"] == final_value
+    assert len(final_state.state["streamed_output"]) == len(streamed_chunks)
+    assert len(final_state.state["logs"]) == 4
+    assert final_state.state["logs"][0]["name"] == "ChatPromptTemplate"
+    assert final_state.state["logs"][0]["final_output"] == dumpd(
+        prompt.invoke({"question": "What is your name?"})
+    )
+    assert final_state.state["logs"][1]["name"] == "RunnableMap"
+    assert sorted(log["name"] for log in final_state.state["logs"][2:]) == [
+        "FakeStreamingListLLM",
+        "RunnablePassthrough",
+    ]
+
 
 @pytest.mark.asyncio
 async def test_map_astream_iterator_input() -> None:
libs/langchain/tests/unit_tests/test_dependencies.py
@@ -39,8 +39,10 @@ def test_required_dependencies(poetry_conf: Mapping[str, Any]) -> None:
         "PyYAML",
         "SQLAlchemy",
         "aiohttp",
+        "anyio",
         "async-timeout",
         "dataclasses-json",
+        "jsonpatch",
         "langsmith",
         "numexpr",
         "numpy",