This commit is contained in:
Eugene Yurtsev
2024-01-17 21:49:14 -05:00
parent 8db4a75f04
commit 12e247fb73
3 changed files with 95 additions and 117 deletions

View File

@@ -7,7 +7,6 @@ import threading
from abc import ABC, abstractmethod
from concurrent.futures import FIRST_COMPLETED, wait
from contextvars import copy_context
from copy import deepcopy
from functools import wraps
from itertools import groupby, tee
from operator import itemgetter
@@ -37,7 +36,7 @@ from typing import (
from typing_extensions import Literal, get_args
from langchain_core._api import beta_decorator
from langchain_core.load.dump import dumpd, dumps
from langchain_core.load.dump import dumpd
from langchain_core.load.serializable import Serializable
from langchain_core.pydantic_v1 import BaseConfig, BaseModel, Field, create_model
from langchain_core.runnables.config import (
@@ -87,10 +86,8 @@ if TYPE_CHECKING:
)
from langchain_core.tracers.log_stream import (
LogEntry,
LogStreamCallbackHandler,
RunLog,
RunLogPatch,
_astream_log_implementation,
)
from langchain_core.tracers.root_listeners import Listener
@@ -683,7 +680,10 @@ class Runnable(Generic[Input, Output], ABC):
_schema_format="original",
)
async for item in _astream_log_implementation(
# Mypy isn't resolving the overloads here
# Likely an issue b/c `self` is being passed through
# and it can't map it to Runnable[Input,Output]?
async for item in _astream_log_implementation( # type: ignore
self,
input,
config,
@@ -735,20 +735,23 @@ class Runnable(Generic[Input, Output], ABC):
| event | name | chunk | input | output |
|----------------------|------------------|---------------------------------|-----------------------------------------------|-------------------------------------------------|
| on_retriever_start | [retriever name] | | {"query": "hello"} | |
| on_retriever_chunk | [retriever name] | {documents: [...]} | | |
| on_retriever_end | [retriever name] | | {"query": "hello"} | {documents: [...]} |
| on_chat_model_start | [model name] | | {"messages": [[SystemMessage, HumanMessage]]} | |
| on_chat_model_stream | [model name] | AIMessageChunk(content="hello") | | |
| on_chat_model_end | [model name] | | {"messages": [[SystemMessage, HumanMessage]]} | {"generations": [...], "llm_output": None, ...} |
| on_llm_start | [model name] | | {'input': 'hello'} | |
| on_llm_stream | [model name] | 'Hello' | | |
| on_llm_end           | [model name]     |                                 | {'input': 'hello'}                            | 'Hello human!'                                  |
| on_chain_start | format_docs | | | |
| on_chain_stream | format_docs | "hello world!, goodbye world!" | | |
| on_chain_end | format_docs | | [Document(...)] | "hello world!, goodbye world!" |
| on_tool_start | some_tool | | {"x": 1, "y": "2"} | |
| on_tool_stream | some_tool | {"x": 1, "y": "2"} | | |
| on_tool_end | some_tool | | | {"x": 1, "y": "2"} |
| on_retriever_start | [retriever name] | | {"query": "hello"} | |
| on_retriever_chunk | [retriever name] | {documents: [...]} | | |
| on_retriever_end | [retriever name] | | {"query": "hello"} | {documents: [...]} |
| on_prompt_start | [template_name] | | {"question": "hello"} | |
| on_prompt_end | [template_name] | | {"question": "hello"} | ChatPromptValue(messages: [SystemMessage, ...]) |
| on_chat_model_start | [model name] | | {"messages": [[SystemMessage, HumanMessage]]} | |
| on_chat_model_stream | [model name] | AIMessageChunk(content="hello") | | |
| on_chat_model_end | [model name] | | {"messages": [[SystemMessage, HumanMessage]]} | {"generations": [...], "llm_output": None, ...} |
```python
def format_docs(docs: List[Document]) -> str:

View File

@@ -11,11 +11,13 @@ from typing import (
Dict,
List,
Literal,
NotRequired,
Optional,
Sequence,
TypedDict,
TypeVar,
Union,
overload,
)
from uuid import UUID
@@ -26,7 +28,7 @@ from langchain_core.load import dumps
from langchain_core.load.load import load
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
from langchain_core.runnables import Runnable, RunnableConfig, ensure_config
from langchain_core.runnables.utils import Output
from langchain_core.runnables.utils import Input, Output
from langchain_core.tracers.base import BaseTracer
from langchain_core.tracers.schemas import Run
@@ -51,23 +53,12 @@ class LogEntry(TypedDict):
"""List of LLM tokens streamed by this run, if applicable."""
streamed_output: List[Any]
"""List of output chunks streamed by this run, if available."""
inputs: Optional[Any]
"""Inputs to this run.
The inputs will be a dictionary, matching
"""
inputs: NotRequired[Optional[Any]]
"""Inputs to this run. Not available currently via astream_log."""
final_output: Optional[Any]
"""Final output of this run.
Only available after the run has finished successfully.
"""Final output of this run.
Schema:
Retriever:
- Sequence[Document]
"""
Only available after the run has finished successfully."""
end_time: Optional[str]
"""ISO-8601 timestamp of when the run ended.
Only available after the run has finished."""
@@ -323,25 +314,30 @@ class LogStreamCallbackHandler(BaseTracer):
run.name if count == 1 else f"{run.name}:{count}"
)
entry = LogEntry(
id=str(run.id),
name=run.name,
type=run.run_type,
tags=run.tags or [],
metadata=(run.extra or {}).get("metadata", {}),
start_time=run.start_time.isoformat(timespec="milliseconds"),
streamed_output=[],
streamed_output_str=[],
final_output=None,
end_time=None,
)
if self._schema_format == "streaming_events":
# If using streaming events let's add inputs as well
entry["inputs"] = _get_standardized_inputs(run, self._schema_format)
# Add the run to the stream
self.send_stream.send_nowait(
RunLogPatch(
{
"op": "add",
"path": f"/logs/{self._key_map_by_run_id[run.id]}",
"value": LogEntry(
id=str(run.id),
name=run.name,
type=run.run_type,
tags=run.tags or [],
metadata=(run.extra or {}).get("metadata", {}),
start_time=run.start_time.isoformat(timespec="milliseconds"),
streamed_output=[],
streamed_output_str=[],
inputs=_get_standardized_inputs(run, self._schema_format),
final_output=None,
end_time=None,
),
"value": entry,
}
)
)
@@ -439,11 +435,14 @@ def _get_standardized_inputs(
invocation using named arguments.
A None means that the input is not yet known!
"""
inputs = load(run.inputs)
if schema_format == "original":
# Return the old schema, without standardizing anything
return inputs
raise NotImplementedError(
"Do not assign inputs with the original schema; drop the key for now. "
"When inputs are added to astream_log they should be added with "
"standardized schema for streaming events."
)
inputs = load(run.inputs)
if run.run_type in {"retriever", "llm", "chat_model"}:
return inputs
@@ -489,8 +488,36 @@ def _get_standardized_outputs(
return None
@overload
def _astream_log_implementation(
runnable: Runnable[Input, Output],
input: Any,
config: Optional[RunnableConfig] = None,
*,
stream: LogStreamCallbackHandler,
diff: Literal[True] = True,
with_streamed_output_list: bool = True,
**kwargs: Any,
) -> AsyncIterator[RunLogPatch]:
...
@overload
def _astream_log_implementation(
runnable: Runnable[Input, Output],
input: Any,
config: Optional[RunnableConfig] = None,
*,
stream: LogStreamCallbackHandler,
diff: Literal[False],
with_streamed_output_list: bool = True,
**kwargs: Any,
) -> AsyncIterator[RunLog]:
...
async def _astream_log_implementation(
runnable: Runnable,
runnable: Runnable[Input, Output],
input: Any,
config: Optional[RunnableConfig] = None,
*,
@@ -498,7 +525,12 @@ async def _astream_log_implementation(
diff: bool = True,
with_streamed_output_list: bool = True,
**kwargs: Any,
):
) -> Union[AsyncIterator[RunLogPatch], AsyncIterator[RunLog]]:
"""Implementation of astream_log for a given runnable.
The implementation has been factored out (at least temporarily) as both
astream_log and astream_events rely on it.
"""
import jsonpatch # type: ignore[import]
from langchain_core.callbacks.base import BaseCallbackManager

View File

@@ -1639,7 +1639,6 @@ async def test_prompt() -> None:
),
"id": "00000000-0000-0000-0000-000000000000",
"logs": {},
"name": "ChatPromptTemplate",
"streamed_output": [
ChatPromptValue(
messages=[
@@ -1649,9 +1648,12 @@ async def test_prompt() -> None:
)
],
"type": "prompt",
"name": "ChatPromptTemplate",
},
)
# nested inside trace_with_chain_group
async with atrace_as_chain_group("a_group") as manager:
stream_log_nested = [
part
@@ -2092,10 +2094,10 @@ async def test_prompt_with_llm(
"op": "replace",
"path": "",
"value": {
"final_output": None,
"logs": {},
"name": "RunnableSequence",
"final_output": None,
"streamed_output": [],
"name": "RunnableSequence",
"type": "chain",
},
}
@@ -2107,7 +2109,6 @@ async def test_prompt_with_llm(
"value": {
"end_time": None,
"final_output": None,
"inputs": {"question": "What is your name?"},
"metadata": {},
"name": "ChatPromptTemplate",
"start_time": "2023-01-01T00:00:00.000+00:00",
@@ -2119,11 +2120,6 @@ async def test_prompt_with_llm(
}
),
RunLogPatch(
{
"op": "replace",
"path": "/logs/ChatPromptTemplate/inputs",
"value": {"question": "What is your name?"},
},
{
"op": "add",
"path": "/logs/ChatPromptTemplate/final_output",
@@ -2147,12 +2143,6 @@ async def test_prompt_with_llm(
"value": {
"end_time": None,
"final_output": None,
"inputs": {
"prompts": [
"System: You are a nice assistant.\n"
"Human: What is your name?"
]
},
"metadata": {},
"name": "FakeListLLM",
"start_time": "2023-01-01T00:00:00.000+00:00",
@@ -2164,16 +2154,6 @@ async def test_prompt_with_llm(
}
),
RunLogPatch(
{
"op": "replace",
"path": "/logs/FakeListLLM/inputs",
"value": {
"prompts": [
"System: You are a nice assistant.\n"
"Human: What is your name?"
]
},
},
{
"op": "add",
"path": "/logs/FakeListLLM/final_output",
@@ -2318,10 +2298,10 @@ async def test_prompt_with_llm_parser(
"op": "replace",
"path": "",
"value": {
"final_output": None,
"logs": {},
"name": "RunnableSequence",
"final_output": None,
"streamed_output": [],
"name": "RunnableSequence",
"type": "chain",
},
}
@@ -2333,7 +2313,6 @@ async def test_prompt_with_llm_parser(
"value": {
"end_time": None,
"final_output": None,
"inputs": {"question": "What is your name?"},
"metadata": {},
"name": "ChatPromptTemplate",
"start_time": "2023-01-01T00:00:00.000+00:00",
@@ -2345,11 +2324,6 @@ async def test_prompt_with_llm_parser(
}
),
RunLogPatch(
{
"op": "replace",
"path": "/logs/ChatPromptTemplate/inputs",
"value": {"question": "What is your name?"},
},
{
"op": "add",
"path": "/logs/ChatPromptTemplate/final_output",
@@ -2373,12 +2347,6 @@ async def test_prompt_with_llm_parser(
"value": {
"end_time": None,
"final_output": None,
"inputs": {
"prompts": [
"System: You are a nice assistant.\n"
"Human: What is your name?"
]
},
"metadata": {},
"name": "FakeStreamingListLLM",
"start_time": "2023-01-01T00:00:00.000+00:00",
@@ -2390,16 +2358,6 @@ async def test_prompt_with_llm_parser(
}
),
RunLogPatch(
{
"op": "replace",
"path": "/logs/FakeStreamingListLLM/inputs",
"value": {
"prompts": [
"System: You are a nice assistant.\n"
"Human: What is your name?"
]
},
},
{
"op": "add",
"path": "/logs/FakeStreamingListLLM/final_output",
@@ -2430,7 +2388,6 @@ async def test_prompt_with_llm_parser(
"value": {
"end_time": None,
"final_output": None,
"inputs": None,
"metadata": {},
"name": "CommaSeparatedListOutputParser",
"start_time": "2023-01-01T00:00:00.000+00:00",
@@ -2475,15 +2432,10 @@ async def test_prompt_with_llm_parser(
{"op": "add", "path": "/final_output/2", "value": "cat"},
),
RunLogPatch(
{
"op": "replace",
"path": "/logs/CommaSeparatedListOutputParser/inputs",
"value": "bear, dog, cat",
},
{
"op": "add",
"path": "/logs/CommaSeparatedListOutputParser/final_output",
"value": ["bear", "dog", "cat"],
"value": {"output": ["bear", "dog", "cat"]},
},
{
"op": "add",
@@ -2492,7 +2444,6 @@ async def test_prompt_with_llm_parser(
},
),
]
assert stream_log == expected
@@ -2558,7 +2509,7 @@ async def test_stream_log_lists() -> None:
):
del op["value"]["id"]
expected = [
assert stream_log == [
RunLogPatch(
{
"op": "replace",
@@ -2566,8 +2517,8 @@ async def test_stream_log_lists() -> None:
"value": {
"final_output": None,
"logs": {},
"name": "list_producer",
"streamed_output": [],
"name": "list_producer",
"type": "chain",
},
}
@@ -2590,8 +2541,6 @@ async def test_stream_log_lists() -> None:
),
]
assert stream_log == expected
state = add(stream_log)
assert isinstance(state, RunLog)
@@ -3358,21 +3307,15 @@ async def test_map_astream() -> None:
final_state += chunk
final_state = cast(RunLog, final_state)
assert final_state.state["final_output"] == final_value
assert len(final_state.state["streamed_output"]) == len(streamed_chunks)
assert len(final_state.state["logs"]) == 4
assert (
final_state.state["logs"]["ChatPromptTemplate"]["name"] == "ChatPromptTemplate"
)
assert final_state.state["logs"]["ChatPromptTemplate"][
"final_output"
] == ChatPromptValue(
messages=[
SystemMessage(content="You are a nice assistant."),
HumanMessage(content="What is your name?"),
]
assert final_state.state["logs"]["ChatPromptTemplate"]["final_output"] == (
prompt.invoke({"question": "What is your name?"})
)
assert (
final_state.state["logs"]["RunnableParallel<chat,llm,passthrough>"]["name"]
== "RunnableParallel<chat,llm,passthrough>"
@@ -5209,7 +5152,7 @@ async def test_astream_log_deep_copies() -> None:
assert state == {
"final_output": 2,
"logs": {},
"streamed_output": [2],
"name": "add_one",
"type": "chain",
"streamed_output": [2],
}