Mirror of https://github.com/hwchase17/langchain.git, synced 2025-09-13 13:36:15 +00:00
Nc/5oct/runnable release (#11428)
- **Description:** Adds a `diff: bool = True` parameter to `Runnable.astream_log()`. With `diff=True` (the default) the stream yields incremental `RunLogPatch` objects as before; with `diff=False` it yields cumulative `RunLog` states built by aggregating the patches. Includes `@overload` signatures so type checkers infer the right element type from the `diff` flag, corrects the `RunLogPatch.__add__`/`RunLog.__add__` return annotations, makes `BaseTracer._start_trace` tolerate a missing parent run, passes a default to `next`/`py_anext` so an empty mapper stream in `RunnableAssign` no longer raises, and adds error messages to the `RunnablePassthrough.assign()` input asserts.
@@ -49,7 +49,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
     def _start_trace(self, run: Run) -> None:
         """Start a trace for a run."""
         if run.parent_run_id:
-            parent_run = self.run_map[str(run.parent_run_id)]
+            parent_run = self.run_map.get(str(run.parent_run_id))
             if parent_run:
                 self._add_child_run(parent_run, run)
                 parent_run.child_execution_order = max(
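For context, a quick sketch (hypothetical run map, not the tracer's real data) of the failure mode this one-line change removes: indexing with `[]` raises `KeyError` when a parent run was never registered, while `.get()` returns `None` so the existing `if parent_run:` guard can skip linking instead of crashing.

```python
# Minimal sketch, assuming a run map keyed by stringified run UUIDs.
run_map: dict = {}  # hypothetical: the parent run was never registered

parent_id = "b1b1b1b1-0000-0000-0000-000000000000"

# Old behaviour: raises KeyError for an untracked parent run.
try:
    parent = run_map[parent_id]
except KeyError:
    parent = None

# New behaviour: .get() returns None and the guard simply skips linking.
parent = run_map.get(parent_id)
if parent:
    pass  # the tracer would call self._add_child_run(parent, run) here
```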
@@ -54,13 +54,12 @@ class RunState(TypedDict):
     streamed_output: List[Any]
     """List of output chunks streamed by Runnable.stream()"""
 
     final_output: Optional[Any]
-    """Final output of the run, usually the result of aggregating streamed_output.
+    """Final output of the run, usually the result of aggregating (`+`) streamed_output.
     Only available after the run has finished successfully."""
 
     logs: Dict[str, LogEntry]
-    """List of sub-runs contained in this run, if any, in the order they were started.
-    If filters were supplied, this list will contain only the runs that matched the
-    filters."""
+    """Map of run names to sub-runs. If filters were supplied, this list will
+    contain only the runs that matched the filters."""
 
 
 class RunLogPatch:
@@ -74,7 +73,7 @@ class RunLogPatch:
     def __init__(self, *ops: Dict[str, Any]) -> None:
         self.ops = list(ops)
 
-    def __add__(self, other: Union[RunLogPatch, Any]) -> RunLogPatch:
+    def __add__(self, other: Union[RunLogPatch, Any]) -> RunLog:
         if type(other) == RunLogPatch:
             ops = self.ops + other.ops
             state = jsonpatch.apply_patch(None, ops)
@@ -102,7 +101,7 @@ class RunLog(RunLogPatch):
         super().__init__(*ops)
         self.state = state
 
-    def __add__(self, other: Union[RunLogPatch, Any]) -> RunLogPatch:
+    def __add__(self, other: Union[RunLogPatch, Any]) -> RunLog:
         if type(other) == RunLogPatch:
             ops = self.ops + other.ops
             state = jsonpatch.apply_patch(self.state, other.ops)
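The corrected annotations match what `__add__` already did: adding patches concatenates their op lists and materializes a concrete state with `jsonpatch`, so the result is a `RunLog`, not another `RunLogPatch`. A minimal sketch of that semantics with hypothetical ops (assumes the `jsonpatch` package, which these classes use):

```python
import jsonpatch

# The first patch replaces the root document, as the tracer's initial patch does.
ops_a = [{"op": "replace", "path": "", "value": {"streamed_output": []}}]
# Later patches append streamed chunks.
ops_b = [{"op": "add", "path": "/streamed_output/-", "value": "chunk"}]

# RunLogPatch + RunLogPatch: apply the combined ops to an empty (None) document.
state = jsonpatch.apply_patch(None, ops_a + ops_b)
assert state == {"streamed_output": ["chunk"]}
# RunLog + RunLogPatch applies only the new ops to the existing state instead.
```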
@@ -26,16 +26,17 @@ from typing import (
     TypeVar,
     Union,
     cast,
+    overload,
 )
 
-from typing_extensions import get_args
+from typing_extensions import Literal, get_args
 
 if TYPE_CHECKING:
     from langchain.callbacks.manager import (
         AsyncCallbackManagerForChainRun,
         CallbackManagerForChainRun,
     )
-    from langchain.callbacks.tracers.log_stream import RunLogPatch
+    from langchain.callbacks.tracers.log_stream import RunLog, RunLogPatch
     from langchain.schema.runnable.fallbacks import (
         RunnableWithFallbacks as RunnableWithFallbacksT,
     )
@@ -290,11 +291,13 @@ class Runnable(Generic[Input, Output], ABC):
         """
         yield await self.ainvoke(input, config, **kwargs)
 
-    async def astream_log(
+    @overload
+    def astream_log(
         self,
         input: Any,
         config: Optional[RunnableConfig] = None,
         *,
+        diff: Literal[True] = True,
         include_names: Optional[Sequence[str]] = None,
         include_types: Optional[Sequence[str]] = None,
         include_tags: Optional[Sequence[str]] = None,
@@ -303,6 +306,39 @@ class Runnable(Generic[Input, Output], ABC):
         exclude_tags: Optional[Sequence[str]] = None,
         **kwargs: Optional[Any],
     ) -> AsyncIterator[RunLogPatch]:
+        ...
+
+    @overload
+    def astream_log(
+        self,
+        input: Any,
+        config: Optional[RunnableConfig] = None,
+        *,
+        diff: Literal[False],
+        include_names: Optional[Sequence[str]] = None,
+        include_types: Optional[Sequence[str]] = None,
+        include_tags: Optional[Sequence[str]] = None,
+        exclude_names: Optional[Sequence[str]] = None,
+        exclude_types: Optional[Sequence[str]] = None,
+        exclude_tags: Optional[Sequence[str]] = None,
+        **kwargs: Optional[Any],
+    ) -> AsyncIterator[RunLog]:
+        ...
+
+    async def astream_log(
+        self,
+        input: Any,
+        config: Optional[RunnableConfig] = None,
+        *,
+        diff: bool = True,
+        include_names: Optional[Sequence[str]] = None,
+        include_types: Optional[Sequence[str]] = None,
+        include_tags: Optional[Sequence[str]] = None,
+        exclude_names: Optional[Sequence[str]] = None,
+        exclude_types: Optional[Sequence[str]] = None,
+        exclude_tags: Optional[Sequence[str]] = None,
+        **kwargs: Optional[Any],
+    ) -> Union[AsyncIterator[RunLogPatch], AsyncIterator[RunLog]]:
         """
         Stream all output from a runnable, as reported to the callback system.
         This includes all inner runs of LLMs, Retrievers, Tools, etc.
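The pair of `@overload` stubs lets a type checker pick the element type from the `diff` flag at the call site, while the single `async def` implements both behaviours. A self-contained sketch of the same pattern (hypothetical names, synchronous for brevity):

```python
from typing import Dict, List, Literal, Union, overload

@overload
def collect(diff: Literal[True] = True) -> List[Dict[str, int]]: ...

@overload
def collect(diff: Literal[False]) -> Dict[str, int]: ...

def collect(diff: bool = True) -> Union[List[Dict[str, int]], Dict[str, int]]:
    patches = [{"a": 1}, {"b": 2}]
    if diff:
        return patches  # one item per change
    merged: Dict[str, int] = {}
    for p in patches:
        merged.update(p)  # cumulative state
    return merged

# Under mypy: collect() is List[Dict[str, int]]; collect(diff=False) is Dict[str, int].
```

Note the overload taking `Literal[True]` carries the default, so a bare `collect()` resolves to the patch-list signature, mirroring `astream_log`'s default of `diff=True`.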
@@ -317,6 +353,7 @@ class Runnable(Generic[Input, Output], ABC):
         from langchain.callbacks.base import BaseCallbackManager
         from langchain.callbacks.tracers.log_stream import (
             LogStreamCallbackHandler,
+            RunLog,
             RunLogPatch,
         )
@@ -370,8 +407,14 @@ class Runnable(Generic[Input, Output], ABC):
         try:
             # Yield each chunk from the output stream
-            async for log in stream:
-                yield log
+            if diff:
+                async for log in stream:
+                    yield log
+            else:
+                state = RunLog(state=None)  # type: ignore[arg-type]
+                async for log in stream:
+                    state = state + log
+                    yield state
         finally:
             # Wait for the runnable to finish, if not cancelled (eg. by break)
             try:
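Usage-wise, the new `else` branch means `diff=False` yields cumulative `RunLog` objects whose `.state` grows as chunks arrive, exactly as the running `state = state + log` suggests. A hedged consumer sketch (the `print_states` helper and its argument names are illustrative, not part of the PR):

```python
from langchain.schema.runnable import Runnable

async def print_states(chain: Runnable, question: str) -> None:
    # With diff=False each item is the aggregate so far, not an incremental patch.
    async for run_log in chain.astream_log({"question": question}, diff=False):
        print(run_log.state["streamed_output"])
```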
@@ -171,7 +171,9 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
         config: Optional[RunnableConfig] = None,
         **kwargs: Any,
     ) -> Dict[str, Any]:
-        assert isinstance(input, dict)
+        assert isinstance(
+            input, dict
+        ), "The input to RunnablePassthrough.assign() must be a dict."
         return {
             **input,
             **self.mapper.invoke(input, config, **kwargs),
@@ -183,7 +185,9 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
         config: Optional[RunnableConfig] = None,
         **kwargs: Any,
     ) -> Dict[str, Any]:
-        assert isinstance(input, dict)
+        assert isinstance(
+            input, dict
+        ), "The input to RunnablePassthrough.assign() must be a dict."
         return {
             **input,
             **await self.mapper.ainvoke(input, config, **kwargs),
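These reworded asserts change only the failure message, which is the part users actually see. A one-liner showing the difference (a bare `assert isinstance(input, dict)` would just report `AssertionError` with no hint):

```python
try:
    assert isinstance("not a dict", dict), (
        "The input to RunnablePassthrough.assign() must be a dict."
    )
except AssertionError as err:
    print(err)  # -> The input to RunnablePassthrough.assign() must be a dict.
```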
@@ -204,10 +208,16 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
         # get executor to start map output stream in background
         with get_executor_for_config(config or {}) as executor:
             # start map output stream
-            first_map_chunk_future = executor.submit(next, map_output)  # type: ignore
+            first_map_chunk_future = executor.submit(
+                next,
+                map_output,  # type: ignore
+                None,
+            )
             # consume passthrough stream
             for chunk in for_passthrough:
-                assert isinstance(chunk, dict)
+                assert isinstance(
+                    chunk, dict
+                ), "The input to RunnablePassthrough.assign() must be a dict."
                 # remove mapper keys from passthrough chunk, to be overwritten by map
                 filtered = AddableDict(
                     {k: v for k, v in chunk.items() if k not in mapper_keys}
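The three-argument `executor.submit(next, map_output, None)` passes a default to `next`, so an empty mapper stream resolves the future to `None` instead of re-raising `StopIteration` when `future.result()` is read. A minimal sketch of the failure mode being fixed (the empty iterator is hypothetical):

```python
from concurrent.futures import ThreadPoolExecutor

empty = iter(())  # hypothetical mapper stream that yields nothing

with ThreadPoolExecutor() as executor:
    # Old form: executor.submit(next, empty) stores the StopIteration raised
    # in the worker and re-raises it from .result().
    fut = executor.submit(next, empty, None)
    assert fut.result() is None  # the default absorbs the exhausted-iterator case
```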
@@ -233,11 +243,13 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
             map_output = self.mapper.atransform(for_map, config, **kwargs)
             # start map output stream
             first_map_chunk_task: asyncio.Task = asyncio.create_task(
-                py_anext(map_output),  # type: ignore[arg-type]
+                py_anext(map_output, None),  # type: ignore[arg-type]
             )
             # consume passthrough stream
             async for chunk in for_passthrough:
-                assert isinstance(chunk, dict)
+                assert isinstance(
+                    chunk, dict
+                ), "The input to RunnablePassthrough.assign() must be a dict."
                 # remove mapper keys from passthrough chunk, to be overwritten by map output
                 filtered = AddableDict(
                     {k: v for k, v in chunk.items() if k not in mapper_keys}
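The async branch gets the same fix through `py_anext(map_output, None)`. `py_anext` is LangChain's backport of the built-in `anext` (available since Python 3.10); with a default it resolves to `None` on an empty stream instead of raising `StopAsyncIteration`. A sketch using the built-in under that assumption:

```python
import asyncio
from typing import AsyncIterator

async def empty_stream() -> AsyncIterator[dict]:
    return
    yield  # unreachable; marks this function as an async generator

async def main() -> None:
    # anext(agen, default) mirrors next(it, default): no StopAsyncIteration.
    assert await anext(empty_stream(), None) is None

asyncio.run(main())
```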
@@ -1260,6 +1260,42 @@ async def test_prompt() -> None:
         RunLogPatch({"op": "add", "path": "/streamed_output/-", "value": expected}),
     ]
 
+    stream_log_state = [
+        part
+        async for part in prompt.astream_log(
+            {"question": "What is your name?"}, diff=False
+        )
+    ]
+
+    # remove random id
+    stream_log[0].ops[0]["value"]["id"] = "00000000-0000-0000-0000-000000000000"
+    stream_log_state[-1].ops[0]["value"]["id"] = "00000000-0000-0000-0000-000000000000"
+    stream_log_state[-1].state["id"] = "00000000-0000-0000-0000-000000000000"
+
+    # assert output with diff=False matches output with diff=True
+    assert stream_log_state[-1].ops == [op for chunk in stream_log for op in chunk.ops]
+    assert stream_log_state[-1] == RunLog(
+        *[op for chunk in stream_log for op in chunk.ops],
+        state={
+            "final_output": ChatPromptValue(
+                messages=[
+                    SystemMessage(content="You are a nice assistant."),
+                    HumanMessage(content="What is your name?"),
+                ]
+            ),
+            "id": "00000000-0000-0000-0000-000000000000",
+            "logs": {},
+            "streamed_output": [
+                ChatPromptValue(
+                    messages=[
+                        SystemMessage(content="You are a nice assistant."),
+                        HumanMessage(content="What is your name?"),
+                    ]
+                )
+            ],
+        },
+    )
+
 
 def test_prompt_template_params() -> None:
     prompt = ChatPromptTemplate.from_template(