Nc/5oct/runnable release (#11428)

Nuno Campos authored 2023-10-05 14:27:50 +01:00, committed by GitHub
parent 58b7a3ba16
commit 1e59c44d36
6 changed files with 641 additions and 65 deletions

View File

@@ -49,7 +49,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
     def _start_trace(self, run: Run) -> None:
         """Start a trace for a run."""
         if run.parent_run_id:
-            parent_run = self.run_map[str(run.parent_run_id)]
+            parent_run = self.run_map.get(str(run.parent_run_id))
             if parent_run:
                 self._add_child_run(parent_run, run)
                 parent_run.child_execution_order = max(
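
The switch from indexing to .get() means a run whose parent was never tracked no longer raises. A minimal sketch of the difference, using a plain dict in place of run_map:

run_map = {"parent-id": object()}
assert run_map.get("unknown-id") is None  # new behavior: missing parent is tolerated
try:
    run_map["unknown-id"]  # old behavior: a missing parent raised KeyError mid-trace
except KeyError:
    pass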

View File

@@ -54,13 +54,12 @@ class RunState(TypedDict):
     streamed_output: List[Any]
     """List of output chunks streamed by Runnable.stream()"""
 
     final_output: Optional[Any]
-    """Final output of the run, usually the result of aggregating streamed_output.
+    """Final output of the run, usually the result of aggregating (`+`) streamed_output.
     Only available after the run has finished successfully."""
 
     logs: Dict[str, LogEntry]
-    """List of sub-runs contained in this run, if any, in the order they were started.
-    If filters were supplied, this list will contain only the runs that matched the
-    filters."""
+    """Map of run names to sub-runs. If filters were supplied, this list will
+    contain only the runs that matched the filters."""
 
 
 class RunLogPatch:
@@ -74,7 +73,7 @@ class RunLogPatch:
     def __init__(self, *ops: Dict[str, Any]) -> None:
         self.ops = list(ops)
 
-    def __add__(self, other: Union[RunLogPatch, Any]) -> RunLogPatch:
+    def __add__(self, other: Union[RunLogPatch, Any]) -> RunLog:
         if type(other) == RunLogPatch:
             ops = self.ops + other.ops
             state = jsonpatch.apply_patch(None, ops)
@@ -102,7 +101,7 @@ class RunLog(RunLogPatch):
         super().__init__(*ops)
         self.state = state
 
-    def __add__(self, other: Union[RunLogPatch, Any]) -> RunLogPatch:
+    def __add__(self, other: Union[RunLogPatch, Any]) -> RunLog:
         if type(other) == RunLogPatch:
             ops = self.ops + other.ops
             state = jsonpatch.apply_patch(self.state, other.ops)
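
Both __add__ implementations fold JSON Patch ops into a cumulative state, so summing patches produces a full RunLog snapshot. A minimal sketch of that aggregation using only the jsonpatch library, with hypothetical ops:

import jsonpatch

# the first patch replaces the root state; later patches append streamed chunks
ops1 = [{"op": "replace", "path": "", "value": {"streamed_output": []}}]
ops2 = [{"op": "add", "path": "/streamed_output/-", "value": "hello"}]

# concatenating ops and re-applying from scratch is what RunLogPatch.__add__ does
state = jsonpatch.apply_patch(None, ops1 + ops2)
assert state == {"streamed_output": ["hello"]}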

View File

@@ -26,16 +26,17 @@ from typing import (
     TypeVar,
     Union,
     cast,
+    overload,
 )
 
-from typing_extensions import get_args
+from typing_extensions import Literal, get_args
 
 if TYPE_CHECKING:
     from langchain.callbacks.manager import (
         AsyncCallbackManagerForChainRun,
         CallbackManagerForChainRun,
     )
-    from langchain.callbacks.tracers.log_stream import RunLogPatch
+    from langchain.callbacks.tracers.log_stream import RunLog, RunLogPatch
     from langchain.schema.runnable.fallbacks import (
         RunnableWithFallbacks as RunnableWithFallbacksT,
     )
@@ -290,11 +291,13 @@ class Runnable(Generic[Input, Output], ABC):
         """
         yield await self.ainvoke(input, config, **kwargs)
 
-    async def astream_log(
+    @overload
+    def astream_log(
         self,
         input: Any,
         config: Optional[RunnableConfig] = None,
         *,
+        diff: Literal[True] = True,
         include_names: Optional[Sequence[str]] = None,
         include_types: Optional[Sequence[str]] = None,
         include_tags: Optional[Sequence[str]] = None,
@@ -303,6 +306,39 @@ class Runnable(Generic[Input, Output], ABC):
         exclude_tags: Optional[Sequence[str]] = None,
         **kwargs: Optional[Any],
     ) -> AsyncIterator[RunLogPatch]:
+        ...
+
+    @overload
+    def astream_log(
+        self,
+        input: Any,
+        config: Optional[RunnableConfig] = None,
+        *,
+        diff: Literal[False],
+        include_names: Optional[Sequence[str]] = None,
+        include_types: Optional[Sequence[str]] = None,
+        include_tags: Optional[Sequence[str]] = None,
+        exclude_names: Optional[Sequence[str]] = None,
+        exclude_types: Optional[Sequence[str]] = None,
+        exclude_tags: Optional[Sequence[str]] = None,
+        **kwargs: Optional[Any],
+    ) -> AsyncIterator[RunLog]:
+        ...
+
+    async def astream_log(
+        self,
+        input: Any,
+        config: Optional[RunnableConfig] = None,
+        *,
+        diff: bool = True,
+        include_names: Optional[Sequence[str]] = None,
+        include_types: Optional[Sequence[str]] = None,
+        include_tags: Optional[Sequence[str]] = None,
+        exclude_names: Optional[Sequence[str]] = None,
+        exclude_types: Optional[Sequence[str]] = None,
+        exclude_tags: Optional[Sequence[str]] = None,
+        **kwargs: Optional[Any],
+    ) -> Union[AsyncIterator[RunLogPatch], AsyncIterator[RunLog]]:
         """
         Stream all output from a runnable, as reported to the callback system.
         This includes all inner runs of LLMs, Retrievers, Tools, etc.
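
The overloads use a Literal-typed diff flag so type checkers pick the right element type for the returned iterator. A self-contained sketch of the same idiom on a toy function (all names hypothetical):

from typing import Iterator, Literal, Tuple, Union, overload

@overload
def count(n: int, *, pairs: Literal[False] = False) -> Iterator[int]: ...
@overload
def count(n: int, *, pairs: Literal[True]) -> Iterator[Tuple[int, int]]: ...
def count(n: int, *, pairs: bool = False) -> Union[Iterator[int], Iterator[Tuple[int, int]]]:
    # one runtime implementation; the overloads above only affect static typing
    for i in range(n):
        yield (i, i) if pairs else i

ints = count(3)               # checker infers Iterator[int]
tups = count(3, pairs=True)   # checker infers Iterator[Tuple[int, int]]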
@@ -317,6 +353,7 @@ class Runnable(Generic[Input, Output], ABC):
         from langchain.callbacks.base import BaseCallbackManager
         from langchain.callbacks.tracers.log_stream import (
             LogStreamCallbackHandler,
+            RunLog,
             RunLogPatch,
         )
@@ -370,8 +407,14 @@ class Runnable(Generic[Input, Output], ABC):
         try:
             # Yield each chunk from the output stream
-            async for log in stream:
-                yield log
+            if diff:
+                async for log in stream:
+                    yield log
+            else:
+                state = RunLog(state=None)  # type: ignore[arg-type]
+                async for log in stream:
+                    state = state + log
+                    yield state
         finally:
             # Wait for the runnable to finish, if not cancelled (eg. by break)
             try:
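
So with diff=True (the default) callers receive incremental RunLogPatch objects, and with diff=False they receive the accumulated RunLog at each step. A usage sketch, assuming some runnable chain and its inputs:

async def consume(chain, inputs) -> None:
    # default: incremental patches, cheap to ship over a wire
    async for patch in chain.astream_log(inputs):
        print(patch.ops)

    # diff=False: cumulative snapshots; the last one is the full run log
    async for snapshot in chain.astream_log(inputs, diff=False):
        print(snapshot.state["streamed_output"])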

View File

@@ -171,7 +171,9 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
         config: Optional[RunnableConfig] = None,
         **kwargs: Any,
     ) -> Dict[str, Any]:
-        assert isinstance(input, dict)
+        assert isinstance(
+            input, dict
+        ), "The input to RunnablePassthrough.assign() must be a dict."
         return {
             **input,
             **self.mapper.invoke(input, config, **kwargs),
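
For context, a usage sketch of the assign() path this assert guards, assuming the import path of this langchain version:

from langchain.schema.runnable import RunnablePassthrough

chain = RunnablePassthrough.assign(doubled=lambda d: d["num"] * 2)
assert chain.invoke({"num": 3}) == {"num": 3, "doubled": 6}
# chain.invoke(3) now fails with the descriptive message above
# instead of a bare AssertionError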
@@ -183,7 +185,9 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
         config: Optional[RunnableConfig] = None,
         **kwargs: Any,
     ) -> Dict[str, Any]:
-        assert isinstance(input, dict)
+        assert isinstance(
+            input, dict
+        ), "The input to RunnablePassthrough.assign() must be a dict."
         return {
             **input,
             **await self.mapper.ainvoke(input, config, **kwargs),
@@ -204,10 +208,16 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
         # get executor to start map output stream in background
         with get_executor_for_config(config or {}) as executor:
             # start map output stream
-            first_map_chunk_future = executor.submit(next, map_output)  # type: ignore
+            first_map_chunk_future = executor.submit(
+                next,
+                map_output,  # type: ignore
+                None,
+            )
             # consume passthrough stream
             for chunk in for_passthrough:
-                assert isinstance(chunk, dict)
+                assert isinstance(
+                    chunk, dict
+                ), "The input to RunnablePassthrough.assign() must be a dict."
                 # remove mapper keys from passthrough chunk, to be overwritten by map
                 filtered = AddableDict(
                     {k: v for k, v in chunk.items() if k not in mapper_keys}
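
Passing None as the default for next() means an empty mapper stream resolves the background future to None rather than raising StopIteration inside the worker. A minimal sketch:

from concurrent.futures import ThreadPoolExecutor

empty = iter(())
with ThreadPoolExecutor() as executor:
    fut = executor.submit(next, empty, None)
    assert fut.result() is None  # without the default, .result() re-raises StopIteration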
@@ -233,11 +243,13 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
         map_output = self.mapper.atransform(for_map, config, **kwargs)
         # start map output stream
         first_map_chunk_task: asyncio.Task = asyncio.create_task(
-            py_anext(map_output),  # type: ignore[arg-type]
+            py_anext(map_output, None),  # type: ignore[arg-type]
         )
         # consume passthrough stream
         async for chunk in for_passthrough:
-            assert isinstance(chunk, dict)
+            assert isinstance(
+                chunk, dict
+            ), "The input to RunnablePassthrough.assign() must be a dict."
             # remove mapper keys from passthrough chunk, to be overwritten by map output
             filtered = AddableDict(
                 {k: v for k, v in chunk.items() if k not in mapper_keys}
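
py_anext(map_output, None) is the async analogue of next(it, None); Python 3.10's builtin anext has the same default behavior. A minimal sketch, assuming Python >= 3.10:

import asyncio

async def empty():
    return
    yield  # unreachable; marks this as an async generator

async def main() -> None:
    first = await anext(empty(), None)  # default turns exhaustion into None
    assert first is None

asyncio.run(main())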

View File

@@ -1260,6 +1260,42 @@ async def test_prompt() -> None:
         RunLogPatch({"op": "add", "path": "/streamed_output/-", "value": expected}),
     ]
 
+    stream_log_state = [
+        part
+        async for part in prompt.astream_log(
+            {"question": "What is your name?"}, diff=False
+        )
+    ]
+
+    # remove random id
+    stream_log[0].ops[0]["value"]["id"] = "00000000-0000-0000-0000-000000000000"
+    stream_log_state[-1].ops[0]["value"]["id"] = "00000000-0000-0000-0000-000000000000"
+    stream_log_state[-1].state["id"] = "00000000-0000-0000-0000-000000000000"
+
+    # assert output with diff=False matches output with diff=True
+    assert stream_log_state[-1].ops == [op for chunk in stream_log for op in chunk.ops]
+    assert stream_log_state[-1] == RunLog(
+        *[op for chunk in stream_log for op in chunk.ops],
+        state={
+            "final_output": ChatPromptValue(
+                messages=[
+                    SystemMessage(content="You are a nice assistant."),
+                    HumanMessage(content="What is your name?"),
+                ]
+            ),
+            "id": "00000000-0000-0000-0000-000000000000",
+            "logs": {},
+            "streamed_output": [
+                ChatPromptValue(
+                    messages=[
+                        SystemMessage(content="You are a nice assistant."),
+                        HumanMessage(content="What is your name?"),
+                    ]
+                )
+            ],
+        },
+    )
+
+
 def test_prompt_template_params() -> None:
     prompt = ChatPromptTemplate.from_template(